AuraFlow model implementation.

This commit is contained in:
comfyanonymous
2024-07-11 16:51:06 -04:00
parent f45157e3ac
commit 9f291d75b3
12 changed files with 1744 additions and 2 deletions

View File

@@ -7,6 +7,7 @@ from . import sd2_clip
from . import sdxl_clip
from . import sd3_clip
from . import sa_t5
import comfy.text_encoders.aura_t5
from . import supported_models_base
from . import latent_formats
@@ -556,7 +557,28 @@ class StableAudio(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model)
class AuraFlow(supported_models_base.BASE):
unet_config = {
"cond_seq_dim": 2048,
}
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio]
sampling_settings = {
"multiplier": 1.0,
}
unet_extra_config = {}
latent_format = latent_formats.SDXL
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.AuraFlow(self, device=device)
return out
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model)
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow]
models += [SVD_img2vid]