Move latent scale factor from VAE to model.

This commit is contained in:
comfyanonymous
2023-06-23 02:14:12 -04:00
parent 30a3861946
commit 8607c2d42d
7 changed files with 73 additions and 33 deletions

View File

@@ -7,6 +7,7 @@ from . import sd2_clip
from . import sdxl_clip
from . import supported_models_base
from . import latent_formats
class SD15(supported_models_base.BASE):
unet_config = {
@@ -21,7 +22,7 @@ class SD15(supported_models_base.BASE):
"num_head_channels": -1,
}
vae_scale_factor = 0.18215
latent_format = latent_formats.SD15
def process_clip_state_dict(self, state_dict):
k = list(state_dict.keys())
@@ -48,7 +49,7 @@ class SD20(supported_models_base.BASE):
"adm_in_channels": None,
}
vae_scale_factor = 0.18215
latent_format = latent_formats.SD15
def v_prediction(self, state_dict):
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
@@ -97,10 +98,10 @@ class SDXLRefiner(supported_models_base.BASE):
"transformer_depth": [0, 4, 4, 0],
}
vae_scale_factor = 0.13025
latent_format = latent_formats.SDXL
def get_model(self, state_dict):
return model_base.SDXLRefiner(self.unet_config)
return model_base.SDXLRefiner(self)
def process_clip_state_dict(self, state_dict):
keys_to_replace = {}
@@ -124,10 +125,10 @@ class SDXL(supported_models_base.BASE):
"adm_in_channels": 2816
}
vae_scale_factor = 0.13025
latent_format = latent_formats.SDXL
def get_model(self, state_dict):
return model_base.SDXL(self.unet_config)
return model_base.SDXL(self)
def process_clip_state_dict(self, state_dict):
keys_to_replace = {}