mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-12 12:37:01 +00:00
Refactor cond_concat into model object.
This commit is contained in:
@@ -26,6 +26,7 @@ class BaseModel(torch.nn.Module):
|
||||
self.adm_channels = unet_config.get("adm_in_channels", None)
|
||||
if self.adm_channels is None:
|
||||
self.adm_channels = 0
|
||||
self.inpaint_model = False
|
||||
print("model_type", model_type.name)
|
||||
print("adm", self.adm_channels)
|
||||
|
||||
@@ -71,6 +72,37 @@ class BaseModel(torch.nn.Module):
|
||||
def encode_adm(self, **kwargs):
|
||||
return None
|
||||
|
||||
def cond_concat(self, **kwargs):
|
||||
if self.inpaint_model:
|
||||
concat_keys = ("mask", "masked_image")
|
||||
cond_concat = []
|
||||
denoise_mask = kwargs.get("denoise_mask", None)
|
||||
latent_image = kwargs.get("latent_image", None)
|
||||
noise = kwargs.get("noise", None)
|
||||
|
||||
def blank_inpaint_image_like(latent_image):
|
||||
blank_image = torch.ones_like(latent_image)
|
||||
# these are the values for "zero" in pixel space translated to latent space
|
||||
blank_image[:,0] *= 0.8223
|
||||
blank_image[:,1] *= -0.6876
|
||||
blank_image[:,2] *= 0.6364
|
||||
blank_image[:,3] *= 0.1380
|
||||
return blank_image
|
||||
|
||||
for ck in concat_keys:
|
||||
if denoise_mask is not None:
|
||||
if ck == "mask":
|
||||
cond_concat.append(denoise_mask[:,:1])
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
|
||||
else:
|
||||
if ck == "mask":
|
||||
cond_concat.append(torch.ones_like(noise)[:,:1])
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(blank_inpaint_image_like(noise))
|
||||
return cond_concat
|
||||
return None
|
||||
|
||||
def load_model_weights(self, sd, unet_prefix=""):
|
||||
to_load = {}
|
||||
keys = list(sd.keys())
|
||||
@@ -112,7 +144,7 @@ class BaseModel(torch.nn.Module):
|
||||
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
|
||||
|
||||
def set_inpaint(self):
|
||||
self.concat_keys = ("mask", "masked_image")
|
||||
self.inpaint_model = True
|
||||
|
||||
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
|
||||
adm_inputs = []
|
||||
|
Reference in New Issue
Block a user