Support loading unet files in diffusers format.

This commit is contained in:
comfyanonymous
2023-07-05 17:34:45 -04:00
parent e57cba4c61
commit af7a49916b
9 changed files with 123 additions and 15 deletions

View File

@@ -41,7 +41,7 @@ class BASE:
return False
return True
def v_prediction(self, state_dict):
def v_prediction(self, state_dict, prefix=""):
return False
def inpaint_model(self):
@@ -53,13 +53,13 @@ class BASE:
for x in self.unet_extra_config:
self.unet_config[x] = self.unet_extra_config[x]
def get_model(self, state_dict):
def get_model(self, state_dict, prefix=""):
if self.inpaint_model():
return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict))
return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict, prefix))
elif self.noise_aug_config is not None:
return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict, prefix))
else:
return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict))
return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict, prefix))
def process_clip_state_dict(self, state_dict):
return state_dict