Support properly saving CosXL checkpoints.

This commit is contained in:
comfyanonymous
2024-04-08 00:36:22 -04:00
parent d644b6bcd8
commit 30abc324c2
2 changed files with 14 additions and 2 deletions

View File

@@ -600,7 +600,7 @@ def load_unet(unet_path):
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None):
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
clip_sd = None
load_models = [model]
if clip is not None:
@@ -610,4 +610,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
model_management.load_models_gpu(load_models)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
for k in extra_keys:
sd[k] = extra_keys[k]
comfy.utils.save_torch_file(sd, output_path, metadata=metadata)