Add support for unCLIP SD2.x models.

See _for_testing/unclip in the UI for the new nodes.

unCLIPCheckpointLoader is used to load them.

unCLIPConditioning is used to add the image cond and takes as input a
CLIPVisionEncode output which has been moved to the conditioning section.
This commit is contained in:
comfyanonymous
2023-04-01 23:19:15 -04:00
parent 0d972b85e6
commit 809bcc8ceb
17 changed files with 593 additions and 113 deletions

View File

@@ -1801,3 +1801,75 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
log = super().log_images(*args, **kwargs)
log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
return log
class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion):
def __init__(self, embedder_config=None, embedding_key="jpg", embedding_dropout=0.5,
freeze_embedder=True, noise_aug_config=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.embed_key = embedding_key
self.embedding_dropout = embedding_dropout
# self._init_embedder(embedder_config, freeze_embedder)
self._init_noise_aug(noise_aug_config)
def _init_embedder(self, config, freeze=True):
embedder = instantiate_from_config(config)
if freeze:
self.embedder = embedder.eval()
self.embedder.train = disabled_train
for param in self.embedder.parameters():
param.requires_grad = False
def _init_noise_aug(self, config):
if config is not None:
# use the KARLO schedule for noise augmentation on CLIP image embeddings
noise_augmentor = instantiate_from_config(config)
assert isinstance(noise_augmentor, nn.Module)
noise_augmentor = noise_augmentor.eval()
noise_augmentor.train = disabled_train
self.noise_augmentor = noise_augmentor
else:
self.noise_augmentor = None
def get_input(self, batch, k, cond_key=None, bs=None, **kwargs):
outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs)
z, c = outputs[0], outputs[1]
img = batch[self.embed_key][:bs]
img = rearrange(img, 'b h w c -> b c h w')
c_adm = self.embedder(img)
if self.noise_augmentor is not None:
c_adm, noise_level_emb = self.noise_augmentor(c_adm)
# assume this gives embeddings of noise levels
c_adm = torch.cat((c_adm, noise_level_emb), 1)
if self.training:
c_adm = torch.bernoulli((1. - self.embedding_dropout) * torch.ones(c_adm.shape[0],
device=c_adm.device)[:, None]) * c_adm
all_conds = {"c_crossattn": [c], "c_adm": c_adm}
noutputs = [z, all_conds]
noutputs.extend(outputs[2:])
return noutputs
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, **kwargs):
log = dict()
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True,
return_original_cond=True)
log["inputs"] = x
log["reconstruction"] = xrec
assert self.model.conditioning_key is not None
assert self.cond_stage_key in ["caption", "txt"]
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
log["conditioning"] = xc
uc = self.get_unconditional_conditioning(N, kwargs.get('unconditional_guidance_label', ''))
unconditional_guidance_scale = kwargs.get('unconditional_guidance_scale', 5.)
uc_ = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
ema_scope = self.ema_scope if kwargs.get('use_ema_scope', True) else nullcontext
with ema_scope(f"Sampling"):
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=True,
ddim_steps=kwargs.get('ddim_steps', 50), eta=kwargs.get('ddim_eta', 0.),
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc_, )
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samplescfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
return log

View File

@@ -307,7 +307,16 @@ def model_wrapper(
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t_continuous] * 2)
c_in = torch.cat([unconditional_condition, condition])
if isinstance(condition, dict):
assert isinstance(unconditional_condition, dict)
c_in = dict()
for k in condition:
if isinstance(condition[k], list):
c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))]
else:
c_in[k] = torch.cat([unconditional_condition[k], condition[k]])
else:
c_in = torch.cat([unconditional_condition, condition])
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
return noise_uncond + guidance_scale * (noise - noise_uncond)

View File

@@ -3,7 +3,6 @@ import torch
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
MODEL_TYPES = {
"eps": "noise",
"v": "v"
@@ -51,12 +50,20 @@ class DPMSolverSampler(object):
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
if isinstance(ctmp, torch.Tensor):
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
if isinstance(conditioning, torch.Tensor):
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
# sampling
C, H, W = shape
@@ -83,6 +90,7 @@ class DPMSolverSampler(object):
)
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
lower_order_final=True)
return x.to(device), None
return x.to(device), None