SD1 and SD2 clip and tokenizer code is now more similar to the SDXL one.

This commit is contained in:
comfyanonymous
2023-10-27 15:54:04 -04:00
parent 6ec3f12c6e
commit e60ca6929a
5 changed files with 69 additions and 30 deletions

View File

@@ -38,8 +38,15 @@ class SD15(supported_models_base.BASE):
if ids.dtype == torch.float32:
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
replace_prefix = {}
replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l."
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"clip_l.": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def clip_target(self):
return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
@@ -62,12 +69,12 @@ class SD20(supported_models_base.BASE):
return model_base.ModelType.EPS
def process_clip_state_dict(self, state_dict):
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
replace_prefix[""] = "cond_stage_model.model."
replace_prefix["clip_h"] = "cond_stage_model.model"
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
return state_dict