Move some functions to utils.py

This commit is contained in:
comfyanonymous
2023-09-02 22:33:37 -04:00
parent 766c7b3815
commit a74c5dbf37
3 changed files with 23 additions and 24 deletions

View File

@@ -68,7 +68,7 @@ class SD20(supported_models_base.BASE):
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
replace_prefix[""] = "cond_stage_model.model."
state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
return state_dict
@@ -120,7 +120,7 @@ class SDXLRefiner(supported_models_base.BASE):
keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
@@ -129,7 +129,7 @@ class SDXLRefiner(supported_models_base.BASE):
if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
replace_prefix["clip_g"] = "conditioner.embedders.0.model"
state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g
def clip_target(self):
@@ -167,8 +167,8 @@ class SDXL(supported_models_base.BASE):
keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
@@ -183,7 +183,7 @@ class SDXL(supported_models_base.BASE):
replace_prefix["clip_g"] = "conditioner.embedders.1.model"
replace_prefix["clip_l"] = "conditioner.embedders.0"
state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g
def clip_target(self):