Own BertModel implementation that works with lowvram.

This commit is contained in:
comfyanonymous
2024-07-26 04:32:33 -04:00
parent 25b51b1a8b
commit a9ac56fc0d
2 changed files with 142 additions and 44 deletions

View File

@@ -1,56 +1,15 @@
from comfy import sd1_clip
from transformers import T5TokenizerFast, BertTokenizer, BertModel, modeling_utils, BertConfig
from transformers import BertTokenizer
from .spiece_tokenizer import SPieceTokenizer
from .bert import BertModel
import comfy.text_encoders.t5
import os
import torch
import contextlib
@contextlib.contextmanager
def use_comfy_ops(ops, device=None, dtype=None):
old_torch_nn_linear = torch.nn.Linear
force_device = device
force_dtype = dtype
def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
if force_device is not None:
device = force_device
if force_dtype is not None:
dtype = force_dtype
return ops.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
torch.nn.Linear = linear_with_dtype
try:
yield
finally:
torch.nn.Linear = old_torch_nn_linear
class RobertaWrapper(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = BertConfig(**config_dict)
with use_comfy_ops(operations, device, dtype):
with modeling_utils.no_init_weights():
self.bert = BertModel(config, add_pooling_layer=False)
self.num_layers = config.num_hidden_layers
def get_input_embeddings(self):
return self.bert.get_input_embeddings()
def set_input_embeddings(self, value):
return self.bert.set_input_embeddings(value)
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
intermediate = None
out = self.bert(input_ids=input_tokens, output_hidden_states=intermediate_output is not None, attention_mask=attention_mask)
return out.last_hidden_state, intermediate, out.pooler_output
class HyditBertModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=RobertaWrapper, enable_attention_masks=True, return_attention_masks=True)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True)
class HyditBertTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):