Support base SDXL and SDXL refiner models.

Large refactor of the model detection and loading code.
commit f87ec10a97
parent 9fccf4aa03
Author: comfyanonymous
Date: 2023-06-22 13:03:50 -04:00

16 changed files with 754 additions and 289 deletions


@@ -8,11 +8,14 @@ import zipfile
 
 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
-        z_empty = self.encode(self.empty_tokens)
+        z_empty, _ = self.encode(self.empty_tokens)
         output = []
+        first_pooled = None
         for x in token_weight_pairs:
             tokens = [list(map(lambda a: a[0], x))]
-            z = self.encode(tokens)
+            z, pooled = self.encode(tokens)
+            if first_pooled is None:
+                first_pooled = pooled
             for i in range(len(z)):
                 for j in range(len(z[i])):
                     weight = x[j][1]
@@ -20,7 +23,7 @@ class ClipTokenWeightEncoder:
             output += [z]
         if (len(output) == 0):
             return self.encode(self.empty_tokens)
-        return torch.cat(output, dim=-2).cpu()
+        return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
 
 class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
@@ -50,6 +53,8 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.layer = layer
         self.layer_idx = None
         self.empty_tokens = [[49406] + [49407] * 76]
+        self.text_projection = None
+        self.layer_norm_hidden_state = True
         if layer == "hidden":
             assert layer_idx is not None
             assert abs(layer_idx) <= 12
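
The two new attributes default to SD1 behaviour: no text projection, and layer norm applied to the selected hidden state. A hedged sketch of how an SDXL CLIP-G-style encoder might flip them (the class names and the 1280 width are illustrative, not the commit's actual SDXL wiring):

```python
import torch

class StubSD1Defaults(torch.nn.Module):
    """Stand-in mirroring the patched SD1ClipModel defaults."""
    def __init__(self):
        super().__init__()
        self.text_projection = None          # SD1: pooled output used as-is
        self.layer_norm_hidden_state = True  # SD1: final layer norm applied

class ClipGStyle(StubSD1Defaults):
    """Hypothetical CLIP-G-style configuration; width is illustrative."""
    def __init__(self):
        super().__init__()
        self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280))
        self.layer_norm_hidden_state = False  # keep the raw hidden state
```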
@@ -112,9 +117,13 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             z = outputs.pooler_output[:, None, :]
         else:
             z = outputs.hidden_states[self.layer_idx]
-            z = self.transformer.text_model.final_layer_norm(z)
+            if self.layer_norm_hidden_state:
+                z = self.transformer.text_model.final_layer_norm(z)
 
-        return z
+        pooled_output = outputs.pooler_output
+        if self.text_projection is not None:
+            pooled_output = pooled_output @ self.text_projection
+        return z, pooled_output
 
     def encode(self, tokens):
         return self(tokens)
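
In the new forward path the pooled output gets an optional matrix product: when `text_projection` is set, it maps the encoder's pooled vector into the space the rest of the model expects. The core operation in isolation (sizes are illustrative):

```python
import torch

pooled = torch.randn(2, 1280)              # pooler_output, (batch, hidden)
text_projection = torch.randn(1280, 1280)  # learned projection matrix
projected = pooled @ text_projection       # what forward now returns as pooled_output
print(projected.shape)                     # torch.Size([2, 1280])
```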
@@ -204,7 +213,7 @@ def expand_directory_list(directories):
             dirs.add(root)
     return list(dirs)
 
-def load_embed(embedding_name, embedding_directory):
+def load_embed(embedding_name, embedding_directory, embedding_size):
     if isinstance(embedding_directory, str):
         embedding_directory = [embedding_directory]
 
@@ -253,13 +262,23 @@ def load_embed(embedding_name, embedding_directory):
     if embed_out is None:
         if 'string_to_param' in embed:
             values = embed['string_to_param'].values()
+            embed_out = next(iter(values))
+        elif isinstance(embed, list):
+            out_list = []
+            for x in range(len(embed)):
+                for k in embed[x]:
+                    t = embed[x][k]
+                    if t.shape[-1] != embedding_size:
+                        continue
+                    out_list.append(t.reshape(-1, t.shape[-1]))
+            embed_out = torch.cat(out_list, dim=0)
         else:
             values = embed.values()
-        embed_out = next(iter(values))
+            embed_out = next(iter(values))
     return embed_out
 
 class SD1Tokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
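
The new `elif isinstance(embed, list)` branch handles embedding files that store one dict of tensors per text encoder; `embedding_size` selects only the tensors whose last dimension matches the current encoder. A self-contained sketch of that filtering on dummy data (key names and shapes are illustrative):

```python
import torch

def select_vectors(embed, embedding_size):
    """Mirrors the new list branch of load_embed on an in-memory object."""
    out_list = []
    for x in range(len(embed)):
        for k in embed[x]:
            t = embed[x][k]
            if t.shape[-1] != embedding_size:
                continue  # belongs to the other text encoder; skip it
            out_list.append(t.reshape(-1, t.shape[-1]))
    return torch.cat(out_list, dim=0)

# e.g. a dual-encoder embedding: 768-wide CLIP-L vectors, 1280-wide CLIP-G vectors
embed = [{"clip_l": torch.zeros(4, 768)}, {"clip_g": torch.zeros(4, 1280)}]
print(select_vectors(embed, 768).shape)   # torch.Size([4, 768])
print(select_vectors(embed, 1280).shape)  # torch.Size([4, 1280])
```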
@@ -275,17 +294,18 @@ class SD1Tokenizer:
         self.embedding_directory = embedding_directory
         self.max_word_length = 8
         self.embedding_identifier = "embedding:"
+        self.embedding_size = embedding_size
 
     def _try_get_embedding(self, embedding_name:str):
         '''
         Takes a potential embedding name and tries to retrieve it.
         Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
         '''
-        embed = load_embed(embedding_name, self.embedding_directory)
+        embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size)
         if embed is None:
             stripped = embedding_name.strip(',')
             if len(stripped) < len(embedding_name):
-                embed = load_embed(stripped, self.embedding_directory)
+                embed = load_embed(stripped, self.embedding_directory, self.embedding_size)
             return (embed, embedding_name[len(stripped):])
         return (embed, "")
 
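
`embedding_size` now threads from the tokenizer into `load_embed`, so an embedding whose vectors do not match the encoder's width simply fails to load, and the existing trailing-comma retry still applies. A stub tracing that flow (the in-memory store stands in for embedding files on disk):

```python
def load_embed_stub(name, embedding_size):
    store = {"fantasy-style": 768}  # name -> vector width of a pretend embedding
    return name if store.get(name) == embedding_size else None

def try_get_embedding(name, embedding_size):
    """Mirrors _try_get_embedding: retry once with trailing commas stripped."""
    embed = load_embed_stub(name, embedding_size)
    if embed is None:
        stripped = name.strip(',')
        if len(stripped) < len(name):
            embed = load_embed_stub(stripped, embedding_size)
        return (embed, name[len(stripped):])
    return (embed, "")

print(try_get_embedding("fantasy-style,", 768))  # ('fantasy-style', ',')
print(try_get_embedding("fantasy-style", 1280))  # (None, '')
```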