Support SDXL inpaint models.

This commit is contained in:
comfyanonymous
2023-09-01 15:18:25 -04:00
parent c335fdf200
commit 7931ff0fd9
5 changed files with 22 additions and 16 deletions

View File

@@ -355,13 +355,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
model_config.unet_config = unet_config
if config['model']["target"].endswith("LatentInpaintDiffusion"):
model = model_base.SDInpaint(model_config, model_type=model_type)
elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
else:
model = model_base.BaseModel(model_config, model_type=model_type)
if config['model']["target"].endswith("LatentInpaintDiffusion"):
model.set_inpaint()
if fp16:
model = model.half()