Refactor unCLIP noise augment out of samplers.py

This commit is contained in:
comfyanonymous
2023-06-11 04:01:18 -04:00
parent 7b2f09b5fa
commit c64ca8c0b2
3 changed files with 43 additions and 39 deletions

View File

@@ -60,6 +60,37 @@ class SD21UNCLIP(BaseModel):
super().__init__(unet_config, v_prediction)
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
def encode_adm(self, **kwargs):
unclip_conditioning = kwargs.get("unclip_conditioning", None)
device = kwargs["device"]
if unclip_conditioning is not None:
adm_inputs = []
weights = []
noise_aug = []
for unclip_cond in unclip_conditioning:
adm_cond = unclip_cond["clip_vision_output"].image_embeds
weight = unclip_cond["strength"]
noise_augment = unclip_cond["noise_augmentation"]
noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = self.noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
weights.append(weight)
noise_aug.append(noise_augment)
adm_inputs.append(adm_out)
if len(noise_aug) > 1:
adm_out = torch.stack(adm_inputs).sum(0)
#TODO: add a way to control this
noise_augment = 0.05
noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = self.noise_augmentor(adm_out[:, :self.noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1)
else:
adm_out = torch.zeros((1, self.adm_channels))
return adm_out
class SDInpaint(BaseModel):
def __init__(self, unet_config, v_prediction=False):
super().__init__(unet_config, v_prediction)