Move latent scale factor from VAE to model.

This commit is contained in:
comfyanonymous
2023-06-23 02:14:12 -04:00
parent 30a3861946
commit 8607c2d42d
7 changed files with 73 additions and 33 deletions

View File

@@ -49,16 +49,17 @@ class BASE:
def __init__(self, unet_config):
self.unet_config = unet_config
self.latent_format = self.latent_format()
for x in self.unet_extra_config:
self.unet_config[x] = self.unet_extra_config[x]
def get_model(self, state_dict):
if self.inpaint_model():
return model_base.SDInpaint(self.unet_config, v_prediction=self.v_prediction(state_dict))
return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict))
elif self.noise_aug_config is not None:
return model_base.SD21UNCLIP(self.unet_config, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
else:
return model_base.BaseModel(self.unet_config, v_prediction=self.v_prediction(state_dict))
return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict))
def process_clip_state_dict(self, state_dict):
return state_dict