mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 03:58:22 +00:00
Initial support for the stable audio open model.
This commit is contained in:
@@ -6,6 +6,7 @@ from . import sd1_clip
|
||||
from . import sd2_clip
|
||||
from . import sdxl_clip
|
||||
from . import sd3_clip
|
||||
from . import sa_t5
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@@ -524,7 +525,35 @@ class SD3(supported_models_base.BASE):
|
||||
|
||||
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
|
||||
|
||||
class StableAudio(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"audio_model": "dit1.0",
|
||||
}
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3]
|
||||
sampling_settings = {"sigma_max": 500.0, "sigma_min": 0.03}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.StableAudio1
|
||||
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
vae_key_prefix = ["pretransform.model."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
seconds_start_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_start.": ""}, filter_keys=True)
|
||||
seconds_total_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_total.": ""}, filter_keys=True)
|
||||
return model_base.StableAudio1(self, seconds_start_embedder_weights=seconds_start_sd, seconds_total_embedder_weights=seconds_total_sd, device=device)
|
||||
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
for k in list(state_dict.keys()):
|
||||
if k.endswith(".cross_attend_norm.beta") or k.endswith(".ff_norm.beta") or k.endswith(".pre_norm.beta"): #These weights are all zero
|
||||
state_dict.pop(k)
|
||||
return state_dict
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model)
|
||||
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
Reference in New Issue
Block a user