Support multiple text encoder configurations on SD3.

This commit is contained in:
comfyanonymous
2024-06-11 13:14:43 -04:00
parent 1c34d338d7
commit 5889b7ca0a
3 changed files with 91 additions and 35 deletions

View File

@@ -54,7 +54,7 @@ class SD15(supported_models_base.BASE):
replace_prefix = {"clip_l.": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def clip_target(self):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
class SD20(supported_models_base.BASE):
@@ -97,7 +97,7 @@ class SD20(supported_models_base.BASE):
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
return state_dict
def clip_target(self):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
class SD21UnclipL(SD20):
@@ -159,7 +159,7 @@ class SDXLRefiner(supported_models_base.BASE):
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g
def clip_target(self):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
class SDXL(supported_models_base.BASE):
@@ -228,7 +228,7 @@ class SDXL(supported_models_base.BASE):
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g
def clip_target(self):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
class SSD1B(SDXL):
@@ -299,7 +299,7 @@ class SVD_img2vid(supported_models_base.BASE):
out = model_base.SVD_img2vid(self, device=device)
return out
def clip_target(self):
def clip_target(self, state_dict={}):
return None
class SV3D_u(SVD_img2vid):
@@ -365,7 +365,7 @@ class Stable_Zero123(supported_models_base.BASE):
out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
return out
def clip_target(self):
def clip_target(self, state_dict={}):
return None
class SD_X4Upscaler(SD20):
@@ -439,7 +439,7 @@ class Stable_Cascade_C(supported_models_base.BASE):
out = model_base.StableCascade_C(self, device=device)
return out
def clip_target(self):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
class Stable_Cascade_B(Stable_Cascade_C):
@@ -501,14 +501,29 @@ class SD3(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.SD3
text_encoder_key_prefix = ["text_encoders."] #TODO?
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SD3(self, device=device)
return out
def clip_target(self):
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.SD3ClipModel) #TODO?
def clip_target(self, state_dict={}):
clip_l = False
clip_g = False
t5 = False
pref = self.text_encoder_key_prefix[0]
if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
clip_l = True
if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
clip_g = True
if "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref) in state_dict:
t5 = True
class SD3ClipModel(sd3_clip.SD3ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, device=device, dtype=dtype)
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, SD3ClipModel)
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3]