mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-14 21:45:06 +00:00
Add support for unCLIP SD2.x models.
See _for_testing/unclip in the UI for the new nodes. unCLIPCheckpointLoader is used to load the unCLIP models. unCLIPConditioning is used to add the image conditioning; it takes as input the output of CLIPVisionEncode, a node which has been moved to the conditioning section.
This commit is contained in:
@@ -1801,3 +1801,75 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
|
||||
log = super().log_images(*args, **kwargs)
|
||||
log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
|
||||
return log
|
||||
|
||||
|
||||
class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion):
    """Latent diffusion model with an additional image-embedding condition.

    On top of the conditioning returned by ``LatentDiffusion.get_input``, this
    class computes a ``c_adm`` tensor from an image embedder (optionally passed
    through a noise augmentor) and packs both into a conditioning dict with
    the keys ``"c_crossattn"`` and ``"c_adm"``.
    """

    def __init__(self, embedder_config=None, embedding_key="jpg", embedding_dropout=0.5,
                 freeze_embedder=True, noise_aug_config=None, *args, **kwargs):
        """Set up the embedding options and the optional noise augmentor.

        Args:
            embedder_config: Config for the image embedder. Currently not used
                by the constructor because the ``_init_embedder`` call below
                is disabled.
            embedding_key: Key of the batch entry that ``get_input`` embeds.
            embedding_dropout: Probability with which a sample's embedding is
                zeroed in ``get_input`` while ``self.training`` is true.
            freeze_embedder: Currently not used by the constructor (see
                ``embedder_config``).
            noise_aug_config: Optional config for the noise augmentor; ``None``
                disables noise augmentation.
            *args, **kwargs: Forwarded to ``LatentDiffusion.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.embed_key = embedding_key
        self.embedding_dropout = embedding_dropout
        # The embedder is intentionally not built here, so ``self.embedder``
        # only exists if ``_init_embedder`` is called explicitly.
        # NOTE(review): presumably the image embedding is supplied from
        # outside this class in this code base -- confirm before re-enabling.
        # self._init_embedder(embedder_config, freeze_embedder)
        self._init_noise_aug(noise_aug_config)

    def _init_embedder(self, config, freeze=True):
        """Instantiate the image embedder from ``config`` and store it.

        Args:
            config: Config passed to ``instantiate_from_config``.
            freeze: When true, the embedder is put in eval mode, its ``train``
                method is replaced by ``disabled_train`` and all parameters
                get ``requires_grad = False``.
        """
        embedder = instantiate_from_config(config)
        if freeze:
            embedder = embedder.eval()
            embedder.train = disabled_train
            for param in embedder.parameters():
                param.requires_grad = False
        # Assign in both cases; previously a non-frozen embedder was built
        # and then silently dropped, leaving ``self.embedder`` undefined.
        self.embedder = embedder

    def _init_noise_aug(self, config):
        """Instantiate the optional noise augmentor.

        Sets ``self.noise_augmentor`` to ``None`` when ``config`` is ``None``;
        otherwise to the instantiated module, in eval mode and with its
        ``train`` method replaced by ``disabled_train``.
        """
        if config is None:
            self.noise_augmentor = None
            return

        # use the KARLO schedule for noise augmentation on CLIP image embeddings
        noise_augmentor = instantiate_from_config(config)
        assert isinstance(noise_augmentor, nn.Module)
        noise_augmentor = noise_augmentor.eval()
        noise_augmentor.train = disabled_train
        self.noise_augmentor = noise_augmentor

    def get_input(self, batch, k, cond_key=None, bs=None, **kwargs):
        """Return the parent's inputs with the conditioning replaced by a dict.

        Args:
            batch: Mapping that must contain ``self.embed_key`` in addition to
                whatever ``LatentDiffusion.get_input`` reads.
            k: Forwarded to ``LatentDiffusion.get_input``.
            cond_key: Accepted for signature compatibility; not used here.
            bs: Optional batch-size limit, applied to the embedded images and
                forwarded to the parent.
            **kwargs: Forwarded to ``LatentDiffusion.get_input``.

        Returns:
            A list ``[z, {"c_crossattn": [c], "c_adm": c_adm}, *rest]`` where
            ``z``, ``c`` and ``rest`` come from the parent implementation.

        Raises:
            AttributeError: If no embedder has been initialised.
        """
        outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs)
        z, c = outputs[0], outputs[1]

        img = batch[self.embed_key][:bs]
        img = rearrange(img, 'b h w c -> b c h w')

        embedder = getattr(self, "embedder", None)
        if embedder is None:
            # Same exception type as the bare attribute lookup would raise,
            # but with a message that says what is missing.
            raise AttributeError(
                f"{type(self).__name__} has no embedder; call _init_embedder() "
                "before get_input()"
            )
        c_adm = embedder(img)

        if self.noise_augmentor is not None:
            c_adm, noise_level_emb = self.noise_augmentor(c_adm)
            # assume this gives embeddings of noise levels
            c_adm = torch.cat((c_adm, noise_level_emb), 1)

        if self.training:
            # One Bernoulli draw per sample: keep the embedding with
            # probability (1 - embedding_dropout), otherwise zero it.
            keep_prob = (1. - self.embedding_dropout) * torch.ones(c_adm.shape[0], device=c_adm.device)
            c_adm = torch.bernoulli(keep_prob[:, None]) * c_adm

        all_conds = {"c_crossattn": [c], "c_adm": c_adm}
        noutputs = [z, all_conds]
        noutputs.extend(outputs[2:])
        return noutputs

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=4, **kwargs):
        """Build a dict of images for logging, including a guided sample.

        Args:
            batch: Input batch; must contain ``self.cond_stage_key`` and the
                entries read by ``get_input``.
            N: Batch size used for ``get_input`` and for sampling.
            n_row: Accepted for signature compatibility; not used here.
            **kwargs: Optional settings read with ``kwargs.get``:
                ``unconditional_guidance_label`` (default ``''``),
                ``unconditional_guidance_scale`` (default ``5.``),
                ``use_ema_scope`` (default ``True``), ``ddim_steps``
                (default ``50``) and ``ddim_eta`` (default ``0.``).

        Returns:
            Dict with the keys ``"inputs"``, ``"reconstruction"``,
            ``"conditioning"`` and ``"samplescfg_scale_<scale>"``.
        """
        log = {}
        # get_input must yield exactly five values here; the first and last
        # are not used below.
        _z, c, x, xrec, _xc = self.get_input(batch, self.first_stage_key, bs=N,
                                             return_first_stage_outputs=True,
                                             return_original_cond=True)
        log["inputs"] = x
        log["reconstruction"] = xrec

        assert self.model.conditioning_key is not None
        assert self.cond_stage_key in ["caption", "txt"]
        xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
        log["conditioning"] = xc

        uc = self.get_unconditional_conditioning(N, kwargs.get('unconditional_guidance_label', ''))
        unconditional_guidance_scale = kwargs.get('unconditional_guidance_scale', 5.)

        # The unconditional branch swaps only the cross-attention input; it
        # reuses the same "c_adm" tensor as the conditional branch.
        uc_ = {"c_crossattn": [uc], "c_adm": c["c_adm"]}

        ema_scope = self.ema_scope if kwargs.get('use_ema_scope', True) else nullcontext
        with ema_scope("Sampling"):
            samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=True,
                                             ddim_steps=kwargs.get('ddim_steps', 50),
                                             eta=kwargs.get('ddim_eta', 0.),
                                             unconditional_guidance_scale=unconditional_guidance_scale,
                                             unconditional_conditioning=uc_, )
            x_samples_cfg = self.decode_first_stage(samples_cfg)
            log[f"samplescfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
        return log
|
||||
|
@@ -307,7 +307,16 @@ def model_wrapper(
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t_continuous] * 2)
|
||||
c_in = torch.cat([unconditional_condition, condition])
|
||||
if isinstance(condition, dict):
|
||||
assert isinstance(unconditional_condition, dict)
|
||||
c_in = dict()
|
||||
for k in condition:
|
||||
if isinstance(condition[k], list):
|
||||
c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))]
|
||||
else:
|
||||
c_in[k] = torch.cat([unconditional_condition[k], condition[k]])
|
||||
else:
|
||||
c_in = torch.cat([unconditional_condition, condition])
|
||||
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
|
||||
return noise_uncond + guidance_scale * (noise - noise_uncond)
|
||||
|
||||
|
@@ -3,7 +3,6 @@ import torch
|
||||
|
||||
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
|
||||
|
||||
|
||||
MODEL_TYPES = {
|
||||
"eps": "noise",
|
||||
"v": "v"
|
||||
@@ -51,12 +50,20 @@ class DPMSolverSampler(object):
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list): ctmp = ctmp[0]
|
||||
if isinstance(ctmp, torch.Tensor):
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
elif isinstance(conditioning, list):
|
||||
for ctmp in conditioning:
|
||||
if ctmp.shape[0] != batch_size:
|
||||
print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
if isinstance(conditioning, torch.Tensor):
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
@@ -83,6 +90,7 @@ class DPMSolverSampler(object):
|
||||
)
|
||||
|
||||
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
|
||||
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
|
||||
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
|
||||
lower_order_final=True)
|
||||
|
||||
return x.to(device), None
|
||||
return x.to(device), None
|
||||
|
Reference in New Issue
Block a user