Add support for unCLIP SD2.x models.

See _for_testing/unclip in the UI for the new nodes.

unCLIPCheckpointLoader is used to load them.

unCLIPConditioning is used to add the image conditioning; it takes as input the
output of the CLIPVisionEncode node, which has been moved to the conditioning section.
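
For reference, a minimal sketch of how an unCLIPConditioning-style node might attach the CLIPVisionEncode output to each conditioning entry. The class layout follows ComfyUI's standard node interface and the category comes from this commit message, but the stored key name ("adm") and the exact value format are assumptions here, not the actual implementation:

    # Illustrative sketch only; the real node lives elsewhere in this commit and
    # the "adm" key plus the stored value format are assumptions.
    class unCLIPConditioning:
        @classmethod
        def INPUT_TYPES(s):
            return {"required": {"conditioning": ("CONDITIONING", ),
                                 "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
                                 }}
        RETURN_TYPES = ("CONDITIONING",)
        FUNCTION = "apply_adm"
        CATEGORY = "_for_testing/unclip"

        def apply_adm(self, conditioning, clip_vision_output):
            out = []
            for cond, opts in conditioning:
                opts = opts.copy()
                # Attach the image embedding so the unCLIP model's crossattn-adm
                # conditioning can pick it up at sampling time (assumed key name).
                opts["adm"] = opts.get("adm", []) + [clip_vision_output]
                out.append([cond, opts])
            return (out, )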
comfyanonymous
2023-04-01 23:19:15 -04:00
parent 0d972b85e6
commit 809bcc8ceb
17 changed files with 593 additions and 113 deletions


@@ -12,20 +12,7 @@ from .cldm import cldm
 from .t2i_adapter import adapter
 from . import utils
-def load_torch_file(ckpt):
-    if ckpt.lower().endswith(".safetensors"):
-        import safetensors.torch
-        sd = safetensors.torch.load_file(ckpt, device="cpu")
-    else:
-        pl_sd = torch.load(ckpt, map_location="cpu")
-        if "global_step" in pl_sd:
-            print(f"Global Step: {pl_sd['global_step']}")
-        if "state_dict" in pl_sd:
-            sd = pl_sd["state_dict"]
-        else:
-            sd = pl_sd
-    return sd
+from . import clip_vision
 def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
     m, u = model.load_state_dict(sd, strict=False)
@@ -53,30 +40,7 @@ def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
         if x in sd:
             sd[keys_to_replace[x]] = sd.pop(x)
-    resblock_to_replace = {
-        "ln_1": "layer_norm1",
-        "ln_2": "layer_norm2",
-        "mlp.c_fc": "mlp.fc1",
-        "mlp.c_proj": "mlp.fc2",
-        "attn.out_proj": "self_attn.out_proj",
-    }
-    for resblock in range(24):
-        for x in resblock_to_replace:
-            for y in ["weight", "bias"]:
-                k = "cond_stage_model.model.transformer.resblocks.{}.{}.{}".format(resblock, x, y)
-                k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, resblock_to_replace[x], y)
-                if k in sd:
-                    sd[k_to] = sd.pop(k)
-        for y in ["weight", "bias"]:
-            k_from = "cond_stage_model.model.transformer.resblocks.{}.attn.in_proj_{}".format(resblock, y)
-            if k_from in sd:
-                weights = sd.pop(k_from)
-                for x in range(3):
-                    p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
-                    k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, p[x], y)
-                    sd[k_to] = weights[1024*x:1024*(x + 1)]
+    sd = utils.transformers_convert(sd, "cond_stage_model.model", "cond_stage_model.transformer.text_model", 24)
     for x in load_state_dict_to:
         x.load_state_dict(sd, strict=False)
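
The new utils.transformers_convert presumably factors the removed loop above into a reusable helper parameterized by the old prefix, the new prefix and the layer count. A sketch along those lines (the actual helper lives in utils.py and may differ in details, e.g. it may also handle extra top-level keys):

    def transformers_convert(sd, prefix_from, prefix_to, number):
        # Rename OpenCLIP-style text encoder keys to the transformers
        # CLIPTextModel layout, mirroring the loop removed above.
        resblock_to_replace = {
            "ln_1": "layer_norm1",
            "ln_2": "layer_norm2",
            "mlp.c_fc": "mlp.fc1",
            "mlp.c_proj": "mlp.fc2",
            "attn.out_proj": "self_attn.out_proj",
        }
        for resblock in range(number):
            for x in resblock_to_replace:
                for y in ["weight", "bias"]:
                    k = "{}.transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
                    k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
                    if k in sd:
                        sd[k_to] = sd.pop(k)
            for y in ["weight", "bias"]:
                k_from = "{}.transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
                if k_from in sd:
                    weights = sd.pop(k_from)
                    # in_proj packs q, k and v; split it into the three projections
                    # instead of hard-coding the 1024-wide slices used before.
                    chunk = weights.shape[0] // 3
                    p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
                    for x in range(3):
                        k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
                        sd[k_to] = weights[chunk * x:chunk * (x + 1)]
        return sd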
@@ -123,7 +87,7 @@ LORA_UNET_MAP_RESNET = {
 }
 def load_lora(path, to_load):
-    lora = load_torch_file(path)
+    lora = utils.load_torch_file(path)
     patch_dict = {}
     loaded_keys = set()
     for x in to_load:
@@ -599,7 +563,7 @@ class ControlNet:
         return out
 def load_controlnet(ckpt_path, model=None):
-    controlnet_data = load_torch_file(ckpt_path)
+    controlnet_data = utils.load_torch_file(ckpt_path)
     pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
     pth = False
     sd2 = False
@@ -793,7 +757,7 @@ class StyleModel:
 def load_style_model(ckpt_path):
-    model_data = load_torch_file(ckpt_path)
+    model_data = utils.load_torch_file(ckpt_path)
     keys = model_data.keys()
     if "style_embedding" in keys:
         model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
@@ -804,7 +768,7 @@ def load_style_model(ckpt_path):
 def load_clip(ckpt_path, embedding_directory=None):
-    clip_data = load_torch_file(ckpt_path)
+    clip_data = utils.load_torch_file(ckpt_path)
     config = {}
     if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
         config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
@@ -847,7 +811,7 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
         load_state_dict_to = [w]
     model = instantiate_from_config(config["model"])
-    sd = load_torch_file(ckpt_path)
+    sd = utils.load_torch_file(ckpt_path)
     model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
     if fp16:
@@ -856,10 +820,11 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
     return (ModelPatcher(model), clip, vae)
-def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
-    sd = load_torch_file(ckpt_path)
+def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
+    sd = utils.load_torch_file(ckpt_path)
     sd_keys = sd.keys()
     clip = None
+    clipvision = None
     vae = None
     fp16 = model_management.should_use_fp16()
@@ -884,6 +849,29 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
         w.cond_stage_model = clip.cond_stage_model
         load_state_dict_to = [w]
+    clipvision_key = "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight"
+    noise_aug_config = None
+    if clipvision_key in sd_keys:
+        size = sd[clipvision_key].shape[1]
+        if output_clipvision:
+            clipvision = clip_vision.load_clipvision_from_sd(sd)
+        noise_aug_key = "noise_augmentor.betas"
+        if noise_aug_key in sd_keys:
+            noise_aug_config = {}
+            params = {}
+            noise_schedule_config = {}
+            noise_schedule_config["timesteps"] = sd[noise_aug_key].shape[0]
+            noise_schedule_config["beta_schedule"] = "squaredcos_cap_v2"
+            params["noise_schedule_config"] = noise_schedule_config
+            noise_aug_config['target'] = "ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation"
+            if size == 1280: #h
+                params["timestep_dim"] = 1024
+            elif size == 1024: #l
+                params["timestep_dim"] = 768
+            noise_aug_config['params'] = params
     sd_config = {
         "linear_start": 0.00085,
         "linear_end": 0.012,
@@ -932,7 +920,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
if unet_config["in_channels"] > 4: #inpainting model
if noise_aug_config is not None: #SD2.x unclip model
sd_config["noise_aug_config"] = noise_aug_config
sd_config["image_size"] = 96
sd_config["embedding_dropout"] = 0.25
sd_config["conditioning_key"] = 'crossattn-adm'
model_config["target"] = "ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
elif unet_config["in_channels"] > 4: #inpainting model
sd_config["conditioning_key"] = "hybrid"
sd_config["finetune_keys"] = None
model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
@@ -944,6 +938,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
     else:
         unet_config["num_heads"] = 8 #SD1.x
+    unclip = 'model.diffusion_model.label_emb.0.0.weight'
+    if unclip in sd_keys:
+        unet_config["num_classes"] = "sequential"
+        unet_config["adm_in_channels"] = sd[unclip].shape[1]
     if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction
         k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
         out = sd[k]
@@ -956,4 +955,4 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
     if fp16:
         model = model.half()
-    return (ModelPatcher(model), clip, vae)
+    return (ModelPatcher(model), clip, vae, clipvision)
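
With output_clipvision=True the loader now returns a 4-tuple, so callers unpack the extra CLIP vision model alongside the usual three objects. A minimal usage sketch, assuming it is run from the ComfyUI root and using a placeholder checkpoint path:

    from comfy import sd as comfy_sd

    # Placeholder path to an SD2.x unCLIP checkpoint.
    model, clip, vae, clipvision = comfy_sd.load_checkpoint_guess_config(
        "checkpoints/sd21-unclip-h.ckpt",
        output_vae=True,
        output_clip=True,
        output_clipvision=True,
    )
    # clipvision stays None unless output_clipvision=True and the checkpoint
    # actually embeds the CLIP vision weights (the embedder.* keys above).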