Refactor of sampler code to deal more easily with different model types.

This commit is contained in:
comfyanonymous
2023-07-17 01:22:12 -04:00
parent ac9c038ac2
commit 3ded1a3a04
8 changed files with 68 additions and 53 deletions

View File

@@ -1008,11 +1008,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
if "noise_aug_config" in model_config_params:
noise_aug_config = model_config_params["noise_aug_config"]
v_prediction = False
model_type = model_base.ModelType.EPS
if "parameterization" in model_config_params:
if model_config_params["parameterization"] == "v":
v_prediction = True
model_type = model_base.ModelType.V_PREDICTION
clip = None
vae = None
@@ -1032,11 +1032,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
if config['model']["target"].endswith("LatentInpaintDiffusion"):
model = model_base.SDInpaint(model_config, v_prediction=v_prediction)
model = model_base.SDInpaint(model_config, model_type=model_type)
elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], v_prediction=v_prediction)
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
else:
model = model_base.BaseModel(model_config, v_prediction=v_prediction)
model = model_base.BaseModel(model_config, model_type=model_type)
if fp16:
model = model.half()