Support SDXL inpaint models.

This commit is contained in:
comfyanonymous
2023-09-01 15:18:25 -04:00
parent c335fdf200
commit 7931ff0fd9
5 changed files with 22 additions and 16 deletions

View File

@@ -111,6 +111,9 @@ class BaseModel(torch.nn.Module):
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
def set_inpaint(self):
self.concat_keys = ("mask", "masked_image")
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
adm_inputs = []
weights = []
@@ -148,12 +151,6 @@ class SD21UNCLIP(BaseModel):
else:
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05))
class SDInpaint(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device)
self.concat_keys = ("mask", "masked_image")
def sdxl_pooled(args, noise_augmentor):
if "unclip_conditioning" in args:
return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280]