mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 12:06:23 +00:00
Support saving stable audio checkpoint that can be loaded back.
This commit is contained in:
@@ -627,3 +627,12 @@ class StableAudio1(BaseModel):
|
||||
cross_attn = torch.cat([cross_attn.to(device), seconds_start_embed.repeat((cross_attn.shape[0], 1, 1)), seconds_total_embed.repeat((cross_attn.shape[0], 1, 1))], dim=1)
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||
sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
|
||||
d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()}
|
||||
for k in d:
|
||||
s = d[k]
|
||||
for l in s:
|
||||
sd["{}{}".format(k, l)] = s[l]
|
||||
return sd
|
||||
|
Reference in New Issue
Block a user