Refactor of sampler code to deal more easily with different model types.

This commit is contained in:
comfyanonymous
2023-07-17 01:22:12 -04:00
parent ac9c038ac2
commit 3ded1a3a04
8 changed files with 68 additions and 53 deletions

View File

@@ -53,13 +53,13 @@ class SD20(supported_models_base.BASE):
latent_format = latent_formats.SD15
def v_prediction(self, state_dict, prefix=""):
def model_type(self, state_dict, prefix=""):
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
out = state_dict[k]
if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
return True
return False
return model_base.ModelType.V_PREDICTION
return model_base.ModelType.EPS
def process_clip_state_dict(self, state_dict):
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
@@ -145,8 +145,14 @@ class SDXL(supported_models_base.BASE):
latent_format = latent_formats.SDXL
def model_type(self, state_dict, prefix=""):
if "v_pred" in state_dict:
return model_base.ModelType.V_PREDICTION
else:
return model_base.ModelType.EPS
def get_model(self, state_dict, prefix=""):
return model_base.SDXL(self)
return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix))
def process_clip_state_dict(self, state_dict):
keys_to_replace = {}