Lower T5 memory usage by a few hundred MB.

This commit is contained in:
comfyanonymous
2024-07-31 00:52:34 -04:00
parent 82cae45d44
commit b85216a3c0
3 changed files with 33 additions and 17 deletions

View File

@@ -1,6 +1,7 @@
import torch
import math
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.ops
class T5LayerNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None, operations=None):
@@ -11,7 +12,7 @@ class T5LayerNorm(torch.nn.Module):
def forward(self, x):
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
return self.weight.to(device=x.device, dtype=x.dtype) * x
return comfy.ops.cast_to_input(self.weight, x) * x
activations = {
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
@@ -82,7 +83,7 @@ class T5Attention(torch.nn.Module):
if relative_attention_bias:
self.relative_attention_num_buckets = 32
self.relative_attention_max_distance = 128
self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
self.relative_attention_bias = operations.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device, dtype=dtype)
@staticmethod
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
@@ -132,7 +133,7 @@ class T5Attention(torch.nn.Module):
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
def compute_bias(self, query_length, key_length, device):
def compute_bias(self, query_length, key_length, device, dtype):
"""Compute binned relative position bias"""
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
@@ -143,7 +144,7 @@ class T5Attention(torch.nn.Module):
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values
@@ -152,7 +153,7 @@ class T5Attention(torch.nn.Module):
k = self.k(x)
v = self.v(x)
if self.relative_attention_bias is not None:
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device, x.dtype)
if past_bias is not None:
if mask is not None:
@@ -225,7 +226,7 @@ class T5(torch.nn.Module):
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
self.dtype = dtype
self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)
self.shared = operations.Embedding(config_dict["vocab_size"], model_dim, device=device, dtype=dtype)
def get_input_embeddings(self):
return self.shared
@@ -234,5 +235,5 @@ class T5(torch.nn.Module):
self.shared = embeddings
def forward(self, input_ids, *args, **kwargs):
x = self.shared(input_ids)
x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
return self.encoder(x, *args, **kwargs)