Simpler base model code.

This commit is contained in:
comfyanonymous
2023-06-09 12:24:24 -04:00
parent 4b0b516544
commit de142eaad5
4 changed files with 163 additions and 74 deletions

View File

@@ -248,7 +248,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
c['transformer_options'] = transformer_options
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
del input_x
model_management.throw_exception_if_processing_interrupted()
@@ -460,36 +460,42 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
uncond[temp[1]] = [o[0], n]
def encode_adm(noise_augmentor, conds, batch_size, device):
def encode_adm(conds, batch_size, device, noise_augmentor=None):
for t in range(len(conds)):
x = conds[t]
if 'adm' in x[1]:
adm_inputs = []
weights = []
noise_aug = []
adm_in = x[1]["adm"]
for adm_c in adm_in:
adm_cond = adm_c[0].image_embeds
weight = adm_c[1]
noise_augment = adm_c[2]
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
weights.append(weight)
noise_aug.append(noise_augment)
adm_inputs.append(adm_out)
adm_out = None
if noise_augmentor is not None:
if 'adm' in x[1]:
adm_inputs = []
weights = []
noise_aug = []
adm_in = x[1]["adm"]
for adm_c in adm_in:
adm_cond = adm_c[0].image_embeds
weight = adm_c[1]
noise_augment = adm_c[2]
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
weights.append(weight)
noise_aug.append(noise_augment)
adm_inputs.append(adm_out)
if len(noise_aug) > 1:
adm_out = torch.stack(adm_inputs).sum(0)
#TODO: add a way to control this
noise_augment = 0.05
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1)
if len(noise_aug) > 1:
adm_out = torch.stack(adm_inputs).sum(0)
#TODO: add a way to control this
noise_augment = 0.05
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1)
else:
adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
else:
adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
x[1] = x[1].copy()
x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size)
if 'adm' in x[1]:
adm_out = x[1]["adm"].to(device)
if adm_out is not None:
x[1] = x[1].copy()
x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size)
return conds
@@ -591,14 +597,17 @@ class KSampler:
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if self.model.model.diffusion_model.dtype == torch.float16:
if self.model.get_dtype() == torch.float16:
precision_scope = torch.autocast
else:
precision_scope = contextlib.nullcontext
if hasattr(self.model, 'noise_augmentor'): #unclip
positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device)
negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device)
if self.model.is_adm():
noise_augmentor = None
if hasattr(self.model, 'noise_augmentor'): #unclip
noise_augmentor = self.model.noise_augmentor
positive = encode_adm(positive, noise.shape[0], self.device, noise_augmentor)
negative = encode_adm(negative, noise.shape[0], self.device, noise_augmentor)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}