From 8489cba1405f222f4675c120aee4a3722affb3f8 Mon Sep 17 00:00:00 2001
From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com>
Date: Thu, 13 Apr 2023 22:01:01 +0200
Subject: [PATCH] add unique ID per word/embedding for tokenizer

---
 comfy/sd1_clip.py | 117 ++++++++++++++++++++++++++++------------------
 1 file changed, 71 insertions(+), 46 deletions(-)

diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 4f51657c3..3dd8262ac 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -224,60 +224,85 @@ class SD1Tokenizer:
         self.inv_vocab = {v: k for k, v in vocab.items()}
         self.embedding_directory = embedding_directory
         self.max_word_length = 8
+        self.embedding_identifier = "embedding:"
 
-    def tokenize_with_weights(self, text):
+    def _try_get_embedding(self, name:str):
+        '''
+        Takes a potential embedding name and tries to retrieve it.
+        Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
+        '''
+        embedding_name = name[len(self.embedding_identifier):].strip('\n')
+        embed = load_embed(embedding_name, self.embedding_directory)
+        if embed is None:
+            stripped = embedding_name.strip(',')
+            if len(stripped) < len(embedding_name):
+                embed = load_embed(stripped, self.embedding_directory)
+                return (embed, embedding_name[len(stripped):])
+        return (embed, "")
+
+
+    def tokenize_with_weights(self, text:str):
+        '''
+        Takes a prompt and converts it to a list of (token, weight, word id) elements.
+        Tokens can both be integer tokens and pre computed CLIP tensors.
+        Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
+        Returned list has the dimensions NxM where M is the input size of CLIP
+        '''
         text = escape_important(text)
         parsed_weights = token_weights(text, 1.0)
 
+        #tokenize words
         tokens = []
-        for t in parsed_weights:
-            to_tokenize = unescape_important(t[0]).replace("\n", " ").split(' ')
-            while len(to_tokenize) > 0:
-                word = to_tokenize.pop(0)
-                temp_tokens = []
-                embedding_identifier = "embedding:"
-                if word.startswith(embedding_identifier) and self.embedding_directory is not None:
-                    embedding_name = word[len(embedding_identifier):].strip('\n')
-                    embed = load_embed(embedding_name, self.embedding_directory)
+        for weighted_segment, weight in parsed_weights:
+            to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
+            to_tokenize = [x for x in to_tokenize if x != ""]
+            for word in to_tokenize:
+                #if we find an embedding, deal with the embedding
+                if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
+                    embed, leftover = self._try_get_embedding(word)
                     if embed is None:
-                        stripped = embedding_name.strip(',')
-                        if len(stripped) < len(embedding_name):
-                            embed = load_embed(stripped, self.embedding_directory)
-                            if embed is not None:
-                                to_tokenize.insert(0, embedding_name[len(stripped):])
-
-                    if embed is not None:
-                        if len(embed.shape) == 1:
-                            temp_tokens += [(embed, t[1])]
-                        else:
-                            for x in range(embed.shape[0]):
-                                temp_tokens += [(embed[x], t[1])]
+                        print(f"warning, embedding:{word} does not exist, ignoring")
                     else:
-                        print("warning, embedding:{} does not exist, ignoring".format(embedding_name))
-                elif len(word) > 0:
-                    tt = self.tokenizer(word)["input_ids"][1:-1]
-                    for x in tt:
-                        temp_tokens += [(x, t[1])]
-            tokens_left = self.max_tokens_per_section - (len(tokens) % self.max_tokens_per_section)
+                        if len(embed.shape) == 1:
+                            tokens.append([(embed, weight)])
+                        else:
+                            tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
+                    #if we accidentally have leftover text, continue parsing using leftover, else move on to next word
+                    if leftover != "":
+                        word = leftover
+                    else:
+                        continue
+                #parse word
+                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]])
+
+        #reshape token array to CLIP input size
+        batched_tokens = []
+        batch = []
+        batched_tokens.append(batch)
+        for i, t_group in enumerate(tokens):
+            #start a new batch if there is not enough room
+            if len(t_group) + len(batch) > self.max_tokens_per_section:
+                remaining_length = self.max_tokens_per_section - len(batch)
+                #fill remaining space depending on length of tokens
+                if len(t_group) > self.max_word_length:
+                    #put part of group of tokens in the batch
+                    batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
+                    t_group = t_group[remaining_length:]
+                else:
+                    #filler tokens
+                    batch.extend([(self.end_token, 1.0, 0)] * remaining_length)
+                batch = []
+                batched_tokens.append(batch)
+            #put current group of tokens in the batch
+            batch.extend([(t,w,i+1) for t,w in t_group])
+
+        #fill last batch
+        batch.extend([(self.end_token, 1.0, 0)] * (self.max_tokens_per_section - len(batch)))
+
+        #add start and end tokens
+        batched_tokens = [[(self.start_token, 1.0, 0)] + x + [(self.end_token, 1.0, 0)] for x in batched_tokens]
+        return batched_tokens
 
-            #try not to split words in different sections
-            if tokens_left < len(temp_tokens) and len(temp_tokens) < (self.max_word_length):
-                for x in range(tokens_left):
-                    tokens += [(self.end_token, 1.0)]
-            tokens += temp_tokens
-
-        out_tokens = []
-        for x in range(0, len(tokens), self.max_tokens_per_section):
-            o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_tokens_per_section + x, len(tokens))]
-            o_token += [(self.end_token, 1.0)]
-            if self.pad_with_end:
-                o_token +=[(self.end_token, 1.0)] * (self.max_length - len(o_token))
-            else:
-                o_token +=[(0, 1.0)] * (self.max_length - len(o_token))
-
-            out_tokens += [o_token]
-
-        return out_tokens
 
     def untokenize(self, token_weight_pair):
        return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
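
Reviewer note: a minimal, hypothetical sketch (not part of the patch) of how the new return value could be consumed. It assumes an existing SD1Tokenizer instance named `tokenizer` and an arbitrary prompt string; it relies only on the (token, weight, word id) layout described in the tokenize_with_weights docstring above.

    # hypothetical usage sketch, assuming `tokenizer` is an SD1Tokenizer instance
    batched_tokens = tokenizer.tokenize_with_weights("a (photo:1.2) of a cat embedding:my_style")
    for section in batched_tokens:
        # each section is one CLIP-sized list of (token, weight, word_id) triples
        words = {}
        for token, weight, word_id in section:
            if word_id != 0:  # id 0 marks start/end/filler tokens
                words.setdefault(word_id, []).append((token, weight))
        # `words` now groups tokens by the word or embedding they came from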