mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 12:06:23 +00:00
Support properly saving CosXL checkpoints.
This commit is contained in:
@@ -600,7 +600,7 @@ def load_unet(unet_path):
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||
return model
|
||||
|
||||
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None):
|
||||
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
|
||||
clip_sd = None
|
||||
load_models = [model]
|
||||
if clip is not None:
|
||||
@@ -610,4 +610,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
|
||||
model_management.load_models_gpu(load_models)
|
||||
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
||||
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
|
||||
for k in extra_keys:
|
||||
sd[k] = extra_keys[k]
|
||||
|
||||
comfy.utils.save_torch_file(sd, output_path, metadata=metadata)
|
||||
|
Reference in New Issue
Block a user