Add CheckpointSave node to save checkpoints.

The created checkpoints contain workflow metadata that can be loaded by
dragging them on top of the UI or loading them with the "Load" button.

Checkpoints will be saved in fp16 or fp32 depending on the format ComfyUI
is using for inference on your hardware. To force fp32 use: --force-fp32

Anything that patches the model weights like merging or loras will be
saved.

The output directory is currently set to: output/checkpoints but that might
change in the future.
This commit is contained in:
comfyanonymous
2023-06-26 12:21:07 -04:00
parent b72a7a835a
commit 9b93b920be
12 changed files with 147 additions and 13 deletions

View File

@@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import numpy as np
from . import utils
class BaseModel(torch.nn.Module):
def __init__(self, model_config, v_prediction=False):
@@ -11,6 +12,7 @@ class BaseModel(torch.nn.Module):
unet_config = model_config.unet_config
self.latent_format = model_config.latent_format
self.model_config = model_config
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.diffusion_model = UNetModel(**unet_config)
self.v_prediction = v_prediction
@@ -83,6 +85,16 @@ class BaseModel(torch.nn.Module):
def process_latent_out(self, latent):
return self.latent_format.process_out(latent)
def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
if self.get_dtype() == torch.float16:
clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
class SD21UNCLIP(BaseModel):
def __init__(self, model_config, noise_aug_config, v_prediction=True):