Make highvram and normalvram shift the text encoders to vram and back.

This is faster for big text encoder models than running them on the CPU.
comfyanonymous
2023-07-01 12:37:23 -04:00
parent fa1959e3ef
commit 97ee230682
3 changed files with 46 additions and 20 deletions
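
For context, the behaviour described in the commit message boils down to a load → encode → unload pattern around the text encoder. The sketch below only illustrates that idea and is not ComfyUI's actual model_management logic; the function name and structure are assumptions:

import torch

def encode_on_gpu(text_encoder, tokens, device=torch.device("cuda")):
    # Illustrative only: shift the text encoder's weights to VRAM, run the
    # encode there (much faster for large encoders), then shift them back
    # so the VRAM stays free for the diffusion model.
    if not torch.cuda.is_available():
        return text_encoder(tokens)  # CPU-only fallback
    text_encoder.to(device)
    try:
        with torch.no_grad():
            out = text_encoder(tokens)
    finally:
        text_encoder.to("cpu")
    return out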

@@ -5,6 +5,8 @@ import comfy.ops
 import torch
 import traceback
 import zipfile
+from . import model_management
+import contextlib

 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
@@ -46,7 +48,6 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             with modeling_utils.no_init_weights():
                 self.transformer = CLIPTextModel(config)

-        self.device = device
         self.max_length = max_length
         if freeze:
             self.freeze()
@@ -95,7 +96,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             out_tokens += [tokens_temp]

         if len(embedding_weights) > 0:
-            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=self.device)
+            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
             new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
             n = token_dict_size
             for x in embedding_weights:
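
The device/dtype change in the hunk above matters because the embedding table can now live on the CPU or the GPU (and possibly in fp16) depending on the VRAM mode, so a device captured at __init__ time is no longer reliable. A minimal standalone sketch of the same pattern, with illustrative variable names:

import torch

# Grow an embedding table while inheriting device and dtype from the
# weights that already exist, instead of a device fixed at construction.
current = torch.nn.Embedding(10, 4).to(dtype=torch.float16)  # stand-in for current_embeds
extra_tokens = 2                                             # stand-in for new textual-inversion embeddings

new = torch.nn.Embedding(
    current.num_embeddings + extra_tokens,
    current.embedding_dim,
    device=current.weight.device,  # follow wherever the model currently lives
    dtype=current.weight.dtype,    # keep fp16/fp32 consistent with the old table
)
with torch.no_grad():
    new.weight[:current.num_embeddings] = current.weight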
@@ -106,24 +107,34 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):

     def forward(self, tokens):
         backup_embeds = self.transformer.get_input_embeddings()
+        device = backup_embeds.weight.device
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
-        tokens = torch.LongTensor(tokens).to(self.device)
-        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
-        self.transformer.set_input_embeddings(backup_embeds)
+        tokens = torch.LongTensor(tokens).to(device)

-        if self.layer == "last":
-            z = outputs.last_hidden_state
-        elif self.layer == "pooled":
-            z = outputs.pooler_output[:, None, :]
+        if backup_embeds.weight.dtype != torch.float32:
+            print("autocast clip")
+            precision_scope = torch.autocast
         else:
-            z = outputs.hidden_states[self.layer_idx]
-            if self.layer_norm_hidden_state:
-                z = self.transformer.text_model.final_layer_norm(z)
+            precision_scope = contextlib.nullcontext
+            print("no autocast clip")

-        pooled_output = outputs.pooler_output
-        if self.text_projection is not None:
-            pooled_output = pooled_output @ self.text_projection
-        return z, pooled_output
+        with precision_scope(model_management.get_autocast_device(device)):
+            outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+            self.transformer.set_input_embeddings(backup_embeds)
+
+            if self.layer == "last":
+                z = outputs.last_hidden_state
+            elif self.layer == "pooled":
+                z = outputs.pooler_output[:, None, :]
+            else:
+                z = outputs.hidden_states[self.layer_idx]
+                if self.layer_norm_hidden_state:
+                    z = self.transformer.text_model.final_layer_norm(z)
+
+            pooled_output = outputs.pooler_output
+            if self.text_projection is not None:
+                pooled_output = pooled_output @ self.text_projection
+        return z.float(), pooled_output.float()

     def encode(self, tokens):
         return self(tokens)
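
The new forward() follows a common mixed-precision pattern: choose the precision context from the weight dtype, run the whole transformer call inside it, and cast the results back to fp32 for downstream code. A self-contained sketch of that pattern (model_management.get_autocast_device is replaced by a simple device-type check; the function name is illustrative):

import contextlib
import torch

def encode_with_matching_precision(transformer, tokens):
    weight = next(transformer.parameters())
    # stand-in for model_management.get_autocast_device(device)
    device_type = "cuda" if weight.device.type == "cuda" else "cpu"
    if weight.dtype != torch.float32:
        precision_scope = torch.autocast          # fp16/bf16 weights: run under autocast
    else:
        precision_scope = contextlib.nullcontext  # fp32 weights: plain execution
    with precision_scope(device_type):
        z = transformer(tokens)
    # cast back so callers always receive fp32, as the diff does with .float()
    return z.float()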