From de142eaad5818cf4e448d8edc479c89e9b59aff0 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 9 Jun 2023 12:24:24 -0400
Subject: [PATCH] Simpler base model code.

---
 comfy/diffusers_load.py | 28 ++--------------
 comfy/model_base.py     | 66 +++++++++++++++++++++++++++++++++++++
 comfy/samplers.py       | 71 ++++++++++++++++++++++------------------
 comfy/sd.py             | 72 +++++++++++++++++++++++++++++++----------
 4 files changed, 163 insertions(+), 74 deletions(-)
 create mode 100644 comfy/model_base.py

diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py
index 43877fb8..f494f1d3 100644
--- a/comfy/diffusers_load.py
+++ b/comfy/diffusers_load.py
@@ -4,7 +4,7 @@ import yaml
 
 import folder_paths
 from comfy.ldm.util import instantiate_from_config
-from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE
+from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE, load_checkpoint
 import os.path as osp
 import re
 import torch
@@ -84,28 +84,4 @@ def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, emb
     # Put together new checkpoint
     sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
 
-    clip = None
-    vae = None
-
-    class WeightsLoader(torch.nn.Module):
-        pass
-
-    w = WeightsLoader()
-    load_state_dict_to = []
-    if output_vae:
-        vae = VAE(scale_factor=scale_factor, config=vae_config)
-        w.first_stage_model = vae.first_stage_model
-        load_state_dict_to = [w]
-
-    if output_clip:
-        clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
-        w.cond_stage_model = clip.cond_stage_model
-        load_state_dict_to = [w]
-
-    model = instantiate_from_config(config["model"])
-    model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
-
-    if fp16:
-        model = model.half()
-
-    return ModelPatcher(model), clip, vae
+    return load_checkpoint(embedding_directory=embedding_directory, state_dict=sd, config=config)
diff --git a/comfy/model_base.py b/comfy/model_base.py
new file mode 100644
index 00000000..7370c19f
--- /dev/null
+++ b/comfy/model_base.py
@@ -0,0 +1,66 @@
+import torch
+from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
+from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
+from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
+import numpy as np
+
+class BaseModel(torch.nn.Module):
+    def __init__(self, unet_config, v_prediction=False):
+        super().__init__()
+
+        self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
+        self.diffusion_model = UNetModel(**unet_config)
+        self.v_prediction = v_prediction
+        if self.v_prediction:
+            self.parameterization = "v"
+        else:
+            self.parameterization = "eps"
+        if "adm_in_channels" in unet_config:
+            self.adm_channels = unet_config["adm_in_channels"]
+        else:
+            self.adm_channels = 0
+        print("v_prediction", v_prediction)
+        print("adm", self.adm_channels)
+
+    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        if given_betas is not None:
+            betas = given_betas
+        else:
+            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+
+        self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
+        self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
+        self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
+
+    def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}):
+        if c_concat is not None:
+            xc = torch.cat([x] + c_concat, dim=1)
+        else:
+            xc = x
+        context = torch.cat(c_crossattn, 1)
+        return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options)
+
+    def get_dtype(self):
+        return self.diffusion_model.dtype
+
+    def is_adm(self):
+        return self.adm_channels > 0
+
+class SD21UNCLIP(BaseModel):
+    def __init__(self, unet_config, noise_aug_config, v_prediction=True):
+        super().__init__(unet_config, v_prediction)
+        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
+
+class SDInpaint(BaseModel):
+    def __init__(self, unet_config, v_prediction=False):
+        super().__init__(unet_config, v_prediction)
+        self.concat_keys = ("mask", "masked_image")
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 1fb928f8..a33d150d 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -248,7 +248,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
 
             c['transformer_options'] = transformer_options
 
-            output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
+            output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
             del input_x
 
             model_management.throw_exception_if_processing_interrupted()
@@ -460,36 +460,42 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
             uncond[temp[1]] = [o[0], n]
 
 
-def encode_adm(noise_augmentor, conds, batch_size, device):
+def encode_adm(conds, batch_size, device, noise_augmentor=None):
     for t in range(len(conds)):
         x = conds[t]
-        if 'adm' in x[1]:
-            adm_inputs = []
-            weights = []
-            noise_aug = []
-            adm_in = x[1]["adm"]
-            for adm_c in adm_in:
-                adm_cond = adm_c[0].image_embeds
-                weight = adm_c[1]
-                noise_augment = adm_c[2]
-                noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
-                c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
-                adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
-                weights.append(weight)
-                noise_aug.append(noise_augment)
-                adm_inputs.append(adm_out)
+        adm_out = None
+        if noise_augmentor is not None:
+            if 'adm' in x[1]:
+                adm_inputs = []
+                weights = []
+                noise_aug = []
+                adm_in = x[1]["adm"]
+                for adm_c in adm_in:
+                    adm_cond = adm_c[0].image_embeds
+                    weight = adm_c[1]
+                    noise_augment = adm_c[2]
+                    noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
+                    c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
+                    adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
+                    weights.append(weight)
+                    noise_aug.append(noise_augment)
+                    adm_inputs.append(adm_out)
 
-            if len(noise_aug) > 1:
-                adm_out = torch.stack(adm_inputs).sum(0)
-                #TODO: add a way to control this
-                noise_augment = 0.05
-                noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
-                c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
-                adm_out = torch.cat((c_adm, noise_level_emb), 1)
+                if len(noise_aug) > 1:
+                    adm_out = torch.stack(adm_inputs).sum(0)
+                    #TODO: add a way to control this
+                    noise_augment = 0.05
+                    noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
+                    c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
+                    adm_out = torch.cat((c_adm, noise_level_emb), 1)
+            else:
+                adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
         else:
-            adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
-        x[1] = x[1].copy()
-        x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size)
+            if 'adm' in x[1]:
+                adm_out = x[1]["adm"].to(device)
+        if adm_out is not None:
+            x[1] = x[1].copy()
+            x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size)
 
     return conds
 
@@ -591,14 +597,17 @@ class KSampler:
         apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
         apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
 
-        if self.model.model.diffusion_model.dtype == torch.float16:
+        if self.model.get_dtype() == torch.float16:
            precision_scope = torch.autocast
         else:
            precision_scope = contextlib.nullcontext
 
-        if hasattr(self.model, 'noise_augmentor'): #unclip
-            positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device)
-            negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device)
+        if self.model.is_adm():
+            noise_augmentor = None
+            if hasattr(self.model, 'noise_augmentor'): #unclip
+                noise_augmentor = self.model.noise_augmentor
+            positive = encode_adm(positive, noise.shape[0], self.device, noise_augmentor)
+            negative = encode_adm(negative, noise.shape[0], self.device, noise_augmentor)
 
         extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
 
diff --git a/comfy/sd.py b/comfy/sd.py
index 04eaaa9f..3747f53b 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -15,8 +15,15 @@ from . import utils
 from . import clip_vision
 from . import gligen
 from . import diffusers_convert
+from . import model_base
 
 def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
+    replace_prefix = {"model.diffusion_model.": "diffusion_model."}
+    for rp in replace_prefix:
+        replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), sd.keys())))
+        for x in replace:
+            sd[x[1]] = sd.pop(x[0])
+
     m, u = model.load_state_dict(sd, strict=False)
 
     k = list(sd.keys())
@@ -182,7 +189,7 @@ def model_lora_keys(model, key_map={}):
 
     counter = 0
     for b in range(12):
-        tk = "model.diffusion_model.input_blocks.{}.1".format(b)
+        tk = "diffusion_model.input_blocks.{}.1".format(b)
         up_counter = 0
         for c in LORA_UNET_MAP_ATTENTIONS:
             k = "{}.{}.weight".format(tk, c)
@@ -193,13 +200,13 @@ def model_lora_keys(model, key_map={}):
         if up_counter >= 4:
             counter += 1
     for c in LORA_UNET_MAP_ATTENTIONS:
-        k = "model.diffusion_model.middle_block.1.{}.weight".format(c)
+        k = "diffusion_model.middle_block.1.{}.weight".format(c)
         if k in sdk:
             lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c])
             key_map[lora_key] = k
     counter = 3
     for b in range(12):
-        tk = "model.diffusion_model.output_blocks.{}.1".format(b)
+        tk = "diffusion_model.output_blocks.{}.1".format(b)
         up_counter = 0
         for c in LORA_UNET_MAP_ATTENTIONS:
             k = "{}.{}.weight".format(tk, c)
@@ -223,7 +230,7 @@ def model_lora_keys(model, key_map={}):
     ds_counter = 0
     counter = 0
     for b in range(12):
-        tk = "model.diffusion_model.input_blocks.{}.0".format(b)
+        tk = "diffusion_model.input_blocks.{}.0".format(b)
         key_in = False
         for c in LORA_UNET_MAP_RESNET:
             k = "{}.{}.weight".format(tk, c)
@@ -242,7 +249,7 @@ def model_lora_keys(model, key_map={}):
 
     counter = 0
     for b in range(3):
-        tk = "model.diffusion_model.middle_block.{}".format(b)
+        tk = "diffusion_model.middle_block.{}".format(b)
         key_in = False
         for c in LORA_UNET_MAP_RESNET:
             k = "{}.{}.weight".format(tk, c)
@@ -256,7 +263,7 @@ def model_lora_keys(model, key_map={}):
     counter = 0
     us_counter = 0
     for b in range(12):
-        tk = "model.diffusion_model.output_blocks.{}.0".format(b)
+        tk = "diffusion_model.output_blocks.{}.0".format(b)
         key_in = False
         for c in LORA_UNET_MAP_RESNET:
             k = "{}.{}.weight".format(tk, c)
@@ -332,7 +339,7 @@ class ModelPatcher:
                         patch_list[i] = patch_list[i].to(device)
 
     def model_dtype(self):
-        return self.model.diffusion_model.dtype
+        return self.model.get_dtype()
 
     def add_patches(self, patches, strength=1.0):
         p = {}
@@ -764,7 +771,7 @@ def load_controlnet(ckpt_path, model=None):
         for x in controlnet_data:
             c_m = "control_model."
             if x.startswith(c_m):
-                sd_key = "model.diffusion_model.{}".format(x[len(c_m):])
+                sd_key = "diffusion_model.{}".format(x[len(c_m):])
                 if sd_key in model_sd:
                     cd = controlnet_data[x]
                     cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
@@ -931,9 +938,10 @@ def load_gligen(ckpt_path):
         model = model.half()
     return model
 
-def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
-    with open(config_path, 'r') as stream:
-        config = yaml.safe_load(stream)
+def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
+    if config is None:
+        with open(config_path, 'r') as stream:
+            config = yaml.safe_load(stream)
     model_config_params = config['model']['params']
     clip_config = model_config_params['cond_stage_config']
     scale_factor = model_config_params['scale_factor']
@@ -942,8 +950,19 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, e
     fp16 = False
     if "unet_config" in model_config_params:
         if "params" in model_config_params["unet_config"]:
-            if "use_fp16" in model_config_params["unet_config"]["params"]:
-                fp16 = model_config_params["unet_config"]["params"]["use_fp16"]
+            unet_config = model_config_params["unet_config"]["params"]
+            if "use_fp16" in unet_config:
+                fp16 = unet_config["use_fp16"]
+
+    noise_aug_config = None
+    if "noise_aug_config" in model_config_params:
+        noise_aug_config = model_config_params["noise_aug_config"]
+
+    v_prediction = False
+
+    if "parameterization" in model_config_params:
+        if model_config_params["parameterization"] == "v":
+            v_prediction = True
 
     clip = None
     vae = None
@@ -963,9 +982,16 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, e
         w.cond_stage_model = clip.cond_stage_model
         load_state_dict_to = [w]
 
-    model = instantiate_from_config(config["model"])
-    sd = utils.load_torch_file(ckpt_path)
-    model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
+    if config['model']["target"].endswith("LatentInpaintDiffusion"):
+        model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
+    elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
+        model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
+    else:
+        model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
+
+    if state_dict is None:
+        state_dict = utils.load_torch_file(ckpt_path)
+    model = load_model_weights(model, state_dict, verbose=False, load_state_dict_to=load_state_dict_to)
 
     if fp16:
         model = model.half()
@@ -1073,16 +1099,20 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
     model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
 
+    unclip_model = False
+    inpaint_model = False
     if noise_aug_config is not None: #SD2.x unclip model
         sd_config["noise_aug_config"] = noise_aug_config
         sd_config["image_size"] = 96
         sd_config["embedding_dropout"] = 0.25
         sd_config["conditioning_key"] = 'crossattn-adm'
+        unclip_model = True
         model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
     elif unet_config["in_channels"] > 4: #inpainting model
         sd_config["conditioning_key"] = "hybrid"
         sd_config["finetune_keys"] = None
         model_config["target"] = "comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
+        inpaint_model = True
     else:
         sd_config["conditioning_key"] = "crossattn"
 
@@ -1096,13 +1126,21 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         unet_config["num_classes"] = "sequential"
         unet_config["adm_in_channels"] = sd[unclip].shape[1]
 
+    v_prediction = False
     if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction
         k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
         out = sd[k]
         if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
+            v_prediction = True
             sd_config["parameterization"] = 'v'
 
-    model = instantiate_from_config(model_config)
+    if inpaint_model:
+        model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
+    elif unclip_model:
+        model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
+    else:
+        model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
+
     model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
 
     if fp16: