Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-09-12 20:48:22 +00:00
Own BertModel implementation that works with lowvram.
@@ -1,56 +1,15 @@
 from comfy import sd1_clip
-from transformers import T5TokenizerFast, BertTokenizer, BertModel, modeling_utils, BertConfig
+from transformers import BertTokenizer
 from .spiece_tokenizer import SPieceTokenizer
+from .bert import BertModel
 import comfy.text_encoders.t5
 import os
 
 import torch
-import contextlib
-
-@contextlib.contextmanager
-def use_comfy_ops(ops, device=None, dtype=None):
-    old_torch_nn_linear = torch.nn.Linear
-    force_device = device
-    force_dtype = dtype
-    def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
-        if force_device is not None:
-            device = force_device
-        if force_dtype is not None:
-            dtype = force_dtype
-        return ops.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
-
-    torch.nn.Linear = linear_with_dtype
-    try:
-        yield
-    finally:
-        torch.nn.Linear = old_torch_nn_linear
-
-
-class RobertaWrapper(torch.nn.Module):
-    def __init__(self, config_dict, dtype, device, operations):
-        super().__init__()
-        config = BertConfig(**config_dict)
-        with use_comfy_ops(operations, device, dtype):
-            with modeling_utils.no_init_weights():
-                self.bert = BertModel(config, add_pooling_layer=False)
-
-        self.num_layers = config.num_hidden_layers
-
-    def get_input_embeddings(self):
-        return self.bert.get_input_embeddings()
-
-    def set_input_embeddings(self, value):
-        return self.bert.set_input_embeddings(value)
-
-    def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
-        intermediate = None
-        out = self.bert(input_ids=input_tokens, output_hidden_states=intermediate_output is not None, attention_mask=attention_mask)
-        return out.last_hidden_state, intermediate, out.pooler_output
 
 class HyditBertModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=RobertaWrapper, enable_attention_masks=True, return_attention_masks=True)
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True)
 
 class HyditBertTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
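
Context for the change: the deleted RobertaWrapper had to monkey-patch torch.nn.Linear through use_comfy_ops so that the transformers BertModel would end up built from ComfyUI's operation classes, while the new comfy.text_encoders.bert.BertModel constructs its layers from the injected operations argument directly, which is what lets lowvram mode cast and offload weights per layer without any global patching. A minimal sketch of that pattern follows; it assumes an operations object exposing Linear and LayerNorm (e.g. comfy.ops.manual_cast), and TinyBertLayer is an illustrative name, not the actual comfy/text_encoders/bert.py code.

import torch

class TinyBertLayer(torch.nn.Module):
    # Illustrative sketch only, not the real comfy/text_encoders/bert.py implementation.
    def __init__(self, hidden_size, intermediate_size, dtype=None, device=None, operations=None):
        super().__init__()
        # Every weight-bearing layer is created through the injected `operations`
        # namespace (assumed to expose Linear/LayerNorm, as comfy.ops classes do),
        # so lowvram weight casting/offloading can manage each layer individually
        # instead of relying on a globally monkey-patched torch.nn.Linear.
        self.dense_in = operations.Linear(hidden_size, intermediate_size, dtype=dtype, device=device)
        self.dense_out = operations.Linear(intermediate_size, hidden_size, dtype=dtype, device=device)
        self.norm = operations.LayerNorm(hidden_size, dtype=dtype, device=device)

    def forward(self, x):
        # Plain feed-forward block with a residual connection and layer norm.
        return self.norm(x + self.dense_out(torch.nn.functional.gelu(self.dense_in(x))))

A wrapper such as HyditBertModel can then pass the active comfy ops class as operations when the model is instantiated, and there is no global state to restore afterwards.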