Refactor of sampler code to deal more easily with different model types.

This commit is contained in:
comfyanonymous
2023-07-17 01:22:12 -04:00
parent ac9c038ac2
commit 3ded1a3a04
8 changed files with 68 additions and 53 deletions

View File

@@ -41,8 +41,8 @@ class BASE:
return False
return True
def v_prediction(self, state_dict, prefix=""):
return False
def model_type(self, state_dict, prefix=""):
return model_base.ModelType.EPS
def inpaint_model(self):
return self.unet_config["in_channels"] > 4
@@ -55,11 +55,11 @@ class BASE:
def get_model(self, state_dict, prefix=""):
if self.inpaint_model():
return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict, prefix))
return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix))
elif self.noise_aug_config is not None:
return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict, prefix))
return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix))
else:
return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict, prefix))
return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix))
def process_clip_state_dict(self, state_dict):
return state_dict