Support saving stable audio checkpoint that can be loaded back.

This commit is contained in:
comfyanonymous
2024-06-27 11:06:52 -04:00
parent 5ff3d4eb3a
commit 8ceb5a02a3
3 changed files with 14 additions and 2 deletions

View File

@@ -543,13 +543,16 @@ class StableAudio(supported_models_base.BASE):
seconds_total_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_total.": ""}, filter_keys=True)
return model_base.StableAudio1(self, seconds_start_embedder_weights=seconds_start_sd, seconds_total_embedder_weights=seconds_total_sd, device=device)
def process_unet_state_dict(self, state_dict):
for k in list(state_dict.keys()):
if k.endswith(".cross_attend_norm.beta") or k.endswith(".ff_norm.beta") or k.endswith(".pre_norm.beta"): #These weights are all zero
state_dict.pop(k)
return state_dict
def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model.model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model)