mirror of https://github.com/comfyanonymous/ComfyUI.git
Support base SDXL and SDXL refiner models.
Large refactor of the model detection and loading code.
@@ -8,11 +8,14 @@ import zipfile
 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
-        z_empty = self.encode(self.empty_tokens)
+        z_empty, _ = self.encode(self.empty_tokens)
         output = []
+        first_pooled = None
         for x in token_weight_pairs:
             tokens = [list(map(lambda a: a[0], x))]
-            z = self.encode(tokens)
+            z, pooled = self.encode(tokens)
+            if first_pooled is None:
+                first_pooled = pooled
             for i in range(len(z)):
                 for j in range(len(z[i])):
                     weight = x[j][1]
@@ -20,7 +23,7 @@ class ClipTokenWeightEncoder:
             output += [z]
         if (len(output) == 0):
             return self.encode(self.empty_tokens)
-        return torch.cat(output, dim=-2).cpu()
+        return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
 
 class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
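
Note: after these two hunks, encode() is expected to return a (hidden_states, pooled) pair, and encode_token_weights() hands back the pooled vector of the first chunk alongside the weighted token embeddings. Below is a minimal runnable sketch of the resulting flow with toy tensors standing in for a real CLIP encoder; the per-token interpolation line is an assumption about the context line the first hunk truncates:

import torch

def encode(tokens):
    # stand-in for a real CLIP text encoder: returns (hidden_states, pooled)
    return torch.randn(1, len(tokens[0]), 8), torch.randn(1, 8)

empty_tokens = [[49406] + [49407] * 3]                  # 4-token toy "empty prompt"
token_weight_pairs = [[(320, 1.0), (1125, 1.2), (267, 0.8), (49407, 1.0)]]

z_empty, _ = encode(empty_tokens)
output, first_pooled = [], None
for x in token_weight_pairs:
    tokens = [list(map(lambda a: a[0], x))]
    z, pooled = encode(tokens)
    if first_pooled is None:
        first_pooled = pooled                           # pooled vector of the first chunk
    for i in range(len(z)):
        for j in range(len(z[i])):
            weight = x[j][1]
            # assumed interpolation toward the empty-prompt embedding
            z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
    output += [z]

cond, pooled = torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
print(cond.shape, pooled.shape)                         # [1, 4, 8] and [1, 8]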
@@ -50,6 +53,8 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.layer = layer
         self.layer_idx = None
         self.empty_tokens = [[49406] + [49407] * 76]
+        self.text_projection = None
+        self.layer_norm_hidden_state = True
         if layer == "hidden":
             assert layer_idx is not None
             assert abs(layer_idx) <= 12
@@ -112,9 +117,13 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             z = outputs.pooler_output[:, None, :]
         else:
             z = outputs.hidden_states[self.layer_idx]
-            z = self.transformer.text_model.final_layer_norm(z)
+            if self.layer_norm_hidden_state:
+                z = self.transformer.text_model.final_layer_norm(z)
 
-        return z
+        pooled_output = outputs.pooler_output
+        if self.text_projection is not None:
+            pooled_output = pooled_output @ self.text_projection
+        return z, pooled_output
 
     def encode(self, tokens):
         return self(tokens)
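
Note: the two new attributes let the same class drive both kinds of text encoder: an SD1-style CLIP keeps the final LayerNorm on the selected hidden state and has no text projection, while a bigG-style second encoder (as SDXL uses) wants the un-normalized hidden state plus a learned projection on the pooled output. A sketch of the branch with illustrative tensors; the 1280 width is an assumption:

import torch

hidden = torch.randn(1, 77, 1280)                 # outputs.hidden_states[layer_idx]
pooled = torch.randn(1, 1280)                     # outputs.pooler_output
final_layer_norm = torch.nn.LayerNorm(1280)

layer_norm_hidden_state = False                   # False for the SDXL-style encoder
text_projection = torch.randn(1280, 1280)         # stays None for SD1-style CLIP

if layer_norm_hidden_state:
    hidden = final_layer_norm(hidden)
if text_projection is not None:
    pooled = pooled @ text_projection             # matches pooled_output @ self.text_projection
print(hidden.shape, pooled.shape)                 # shapes unchanged by either branch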
@@ -204,7 +213,7 @@ def expand_directory_list(directories):
         dirs.add(root)
     return list(dirs)
 
-def load_embed(embedding_name, embedding_directory):
+def load_embed(embedding_name, embedding_directory, embedding_size):
     if isinstance(embedding_directory, str):
         embedding_directory = [embedding_directory]
 
@@ -253,13 +262,23 @@ def load_embed(embedding_name, embedding_directory):
     if embed_out is None:
         if 'string_to_param' in embed:
             values = embed['string_to_param'].values()
+            embed_out = next(iter(values))
+        elif isinstance(embed, list):
+            out_list = []
+            for x in range(len(embed)):
+                for k in embed[x]:
+                    t = embed[x][k]
+                    if t.shape[-1] != embedding_size:
+                        continue
+                    out_list.append(t.reshape(-1, t.shape[-1]))
+            embed_out = torch.cat(out_list, dim=0)
         else:
             values = embed.values()
-        embed_out = next(iter(values))
+            embed_out = next(iter(values))
     return embed_out
 
 class SD1Tokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
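
Note: embedding files can now carry tensors for more than one text encoder, so load_embed() filters on the trailing dimension and keeps only tensors matching the requesting encoder's width. A self-contained sketch of the new list branch; the key names and the 1280 width are hypothetical:

import torch

embed = [{"clip_l": torch.randn(2, 768)},         # hypothetical SD1-width tensor
         {"clip_g": torch.randn(2, 1280)}]        # hypothetical wider-encoder tensor

embedding_size = 768
out_list = []
for x in range(len(embed)):
    for k in embed[x]:
        t = embed[x][k]
        if t.shape[-1] != embedding_size:
            continue                              # wrong width: belongs to another encoder
        out_list.append(t.reshape(-1, t.shape[-1]))
embed_out = torch.cat(out_list, dim=0)
print(embed_out.shape)                            # torch.Size([2, 768])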
@@ -275,17 +294,18 @@ class SD1Tokenizer:
         self.embedding_directory = embedding_directory
         self.max_word_length = 8
         self.embedding_identifier = "embedding:"
+        self.embedding_size = embedding_size
 
     def _try_get_embedding(self, embedding_name:str):
         '''
         Takes a potential embedding name and tries to retrieve it.
         Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
         '''
-        embed = load_embed(embedding_name, self.embedding_directory)
+        embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size)
         if embed is None:
             stripped = embedding_name.strip(',')
             if len(stripped) < len(embedding_name):
-                embed = load_embed(stripped, self.embedding_directory)
+                embed = load_embed(stripped, self.embedding_directory, self.embedding_size)
                 return (embed, embedding_name[len(stripped):])
         return (embed, "")
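
Note: a hypothetical caller only needs to pass a different embedding_size to tokenize for a wider encoder; _try_get_embedding() now threads it through to load_embed(). The directory path below is made up, and the import assumes this module lives at comfy.sd1_clip:

from comfy.sd1_clip import SD1Tokenizer

tokenizer = SD1Tokenizer(embedding_directory="models/embeddings",    # hypothetical path
                         embedding_size=1280)                        # width of the target encoder
embed, leftover = tokenizer._try_get_embedding("example_embedding")  # embed is None if nothing matches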