Commit to https://github.com/comfyanonymous/ComfyUI.git
Support multiple text encoder configurations on SD3.
@@ -54,7 +54,7 @@ class SD15(supported_models_base.BASE):
         replace_prefix = {"clip_l.": "cond_stage_model."}
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)
 
-    def clip_target(self):
+    def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
 
 class SD20(supported_models_base.BASE):
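Every clip_target() override in this diff gets the same one-line signature change: the method now receives the checkpoint's state dict, so a model config can decide which text encoders to build from the weights that are actually present. The {} default keeps existing call sites that pass no argument working unchanged. A minimal standalone sketch of that compatibility property (toy ModelConfig class, not ComfyUI code):

# A mutable {} default is a classic Python pitfall, but it is safe in the
# overrides shown here because the method only reads from state_dict and
# never mutates it.
class ModelConfig:
    def clip_target(self, state_dict={}):
        return "custom" if state_dict else "default"

cfg = ModelConfig()
print(cfg.clip_target())                     # legacy call site -> "default"
print(cfg.clip_target(state_dict={"w": 0}))  # new call site    -> "custom"

The hunks that follow apply the identical change to SD20, SDXLRefiner, SDXL, SVD_img2vid, Stable_Zero123, and Stable_Cascade_C; only the SD3 hunk at the end adds new behavior.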
@@ -97,7 +97,7 @@ class SD20(supported_models_base.BASE):
         state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
         return state_dict
 
-    def clip_target(self):
+    def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
 
 class SD21UnclipL(SD20):
@@ -159,7 +159,7 @@ class SDXLRefiner(supported_models_base.BASE):
         state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
         return state_dict_g
 
-    def clip_target(self):
+    def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
 
 class SDXL(supported_models_base.BASE):
@@ -228,7 +228,7 @@ class SDXL(supported_models_base.BASE):
         state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
         return state_dict_g
 
-    def clip_target(self):
+    def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
 
 class SSD1B(SDXL):
@@ -299,7 +299,7 @@ class SVD_img2vid(supported_models_base.BASE):
         out = model_base.SVD_img2vid(self, device=device)
         return out
 
-    def clip_target(self):
+    def clip_target(self, state_dict={}):
         return None
 
 class SV3D_u(SVD_img2vid):
@@ -365,7 +365,7 @@ class Stable_Zero123(supported_models_base.BASE):
         out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
         return out
 
-    def clip_target(self):
+    def clip_target(self, state_dict={}):
         return None
 
 class SD_X4Upscaler(SD20):
@@ -439,7 +439,7 @@ class Stable_Cascade_C(supported_models_base.BASE):
         out = model_base.StableCascade_C(self, device=device)
         return out
 
-    def clip_target(self):
+    def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
 
 class Stable_Cascade_B(Stable_Cascade_C):
@@ -501,14 +501,29 @@ class SD3(supported_models_base.BASE):
 
     unet_extra_config = {}
     latent_format = latent_formats.SD3
-    text_encoder_key_prefix = ["text_encoders."] #TODO?
+    text_encoder_key_prefix = ["text_encoders."]
 
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.SD3(self, device=device)
         return out
 
-    def clip_target(self):
-        return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.SD3ClipModel) #TODO?
+    def clip_target(self, state_dict={}):
+        clip_l = False
+        clip_g = False
+        t5 = False
+        pref = self.text_encoder_key_prefix[0]
+        if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
+            clip_l = True
+        if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
+            clip_g = True
+        if "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref) in state_dict:
+            t5 = True
+
+        class SD3ClipModel(sd3_clip.SD3ClipModel):
+            def __init__(self, device="cpu", dtype=None):
+                super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, device=device, dtype=dtype)
+
+        return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, SD3ClipModel)
 
 
 models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3]
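The new SD3.clip_target() probes the state dict for one landmark key per encoder (the final layer norm weight of CLIP-L, CLIP-G, and T5-XXL) and enables only the encoders whose weights the checkpoint actually ships. A standalone sketch of that detection step, with the key names copied from the diff and a hypothetical detect_sd3_text_encoders() helper standing in for the inline checks:

def detect_sd3_text_encoders(state_dict, pref="text_encoders."):
    # One landmark key per encoder: if the key is present, the checkpoint
    # bundles that encoder's weights.
    probes = {
        "clip_l": "{}clip_l.transformer.text_model.final_layer_norm.weight",
        "clip_g": "{}clip_g.transformer.text_model.final_layer_norm.weight",
        "t5": "{}t5xxl.transformer.encoder.final_layer_norm.weight",
    }
    return {name: key.format(pref) in state_dict for name, key in probes.items()}

# Example: a checkpoint that bundles only CLIP-G and T5-XXL.
sd = {
    "text_encoders.clip_g.transformer.text_model.final_layer_norm.weight": ...,
    "text_encoders.t5xxl.transformer.encoder.final_layer_norm.weight": ...,
}
print(detect_sd3_text_encoders(sd))  # {'clip_l': False, 'clip_g': True, 't5': True}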
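The detected flags are then baked into a nested subclass of sd3_clip.SD3ClipModel whose __init__ takes only (device, dtype), so whatever later instantiates the class returned in ClipTarget never needs to know about the flags. A self-contained sketch of that closure-subclass pattern (toy Base class and make_configured() factory are assumed names, not ComfyUI code):

class Base:
    def __init__(self, clip_l=True, clip_g=True, t5=True, device="cpu", dtype=None):
        self.enabled = {"clip_l": clip_l, "clip_g": clip_g, "t5": t5}

def make_configured(clip_l, clip_g, t5):
    # The nested class captures the flags from the enclosing scope, just as
    # the nested SD3ClipModel in the diff captures clip_l/clip_g/t5.
    class Configured(Base):
        def __init__(self, device="cpu", dtype=None):
            super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, device=device, dtype=dtype)
    return Configured

Clip = make_configured(False, True, True)
print(Clip().enabled)  # {'clip_l': False, 'clip_g': True, 't5': True}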