mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-14 13:35:05 +00:00
Support multiple text encoder configurations on SD3.
This commit is contained in:
@@ -5,6 +5,7 @@ import comfy.t5
|
||||
import torch
|
||||
import os
|
||||
import comfy.model_management
|
||||
import logging
|
||||
|
||||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
@@ -43,42 +44,82 @@ class SD3Tokenizer:
|
||||
return self.clip_g.untokenize(token_weight_pair)
|
||||
|
||||
class SD3ClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
def __init__(self, clip_l=True, clip_g=True, t5=True, device="cpu", dtype=None):
|
||||
super().__init__()
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
|
||||
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
|
||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype)
|
||||
if clip_l:
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
|
||||
else:
|
||||
self.clip_l = None
|
||||
|
||||
if clip_g:
|
||||
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
|
||||
else:
|
||||
self.clip_g = None
|
||||
|
||||
if t5:
|
||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype)
|
||||
else:
|
||||
self.t5xxl = None
|
||||
|
||||
logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}".format(clip_l, clip_g, t5))
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.clip_l.set_clip_options(options)
|
||||
self.clip_g.set_clip_options(options)
|
||||
self.t5xxl.set_clip_options(options)
|
||||
if self.clip_l is not None:
|
||||
self.clip_l.set_clip_options(options)
|
||||
if self.clip_g is not None:
|
||||
self.clip_g.set_clip_options(options)
|
||||
if self.t5xxl is not None:
|
||||
self.t5xxl.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
self.clip_g.reset_clip_options()
|
||||
self.clip_l.reset_clip_options()
|
||||
self.t5xxl.reset_clip_options()
|
||||
if self.clip_l is not None:
|
||||
self.clip_l.reset_clip_options()
|
||||
if self.clip_g is not None:
|
||||
self.clip_g.reset_clip_options()
|
||||
if self.t5xxl is not None:
|
||||
self.t5xxl.reset_clip_options()
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs_l = token_weight_pairs["l"]
|
||||
token_weight_pairs_g = token_weight_pairs["g"]
|
||||
token_weight_pars_t5 = token_weight_pairs["t5xxl"]
|
||||
lg_out = None
|
||||
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
|
||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
||||
out = lg_out
|
||||
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
|
||||
else:
|
||||
pooled = torch.zeros((1, 1280 + 768), device=comfy.model_management.intermediate_device())
|
||||
pooled = None
|
||||
out = None
|
||||
|
||||
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
|
||||
if lg_out is not None:
|
||||
out = torch.cat([lg_out, t5_out], dim=-2)
|
||||
else:
|
||||
out = t5_out
|
||||
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
|
||||
if self.clip_l is not None:
|
||||
lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||
else:
|
||||
l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())
|
||||
|
||||
if self.clip_g is not None:
|
||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||
if lg_out is not None:
|
||||
lg_out = torch.cat([lg_out, g_out], dim=-1)
|
||||
else:
|
||||
lg_out = torch.nn.functional.pad(g_out, (768, 0))
|
||||
else:
|
||||
g_out = None
|
||||
g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())
|
||||
|
||||
if lg_out is not None:
|
||||
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
||||
out = lg_out
|
||||
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
|
||||
|
||||
if self.t5xxl is not None:
|
||||
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
|
||||
if lg_out is not None:
|
||||
out = torch.cat([lg_out, t5_out], dim=-2)
|
||||
else:
|
||||
out = t5_out
|
||||
|
||||
if out is None:
|
||||
out = torch.zeros((1, 77, 4096), device=comfy.model_management.intermediate_device())
|
||||
|
||||
if pooled is None:
|
||||
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
|
||||
|
||||
return out, pooled
|
||||
|
||||
|
Reference in New Issue
Block a user