Use common function for casting weights to input.

This commit is contained in:
comfyanonymous
2024-07-30 05:03:20 -04:00
parent 79040635da
commit 25853d0be8
7 changed files with 51 additions and 31 deletions

View File

@@ -9,6 +9,7 @@ from einops import rearrange
from torch import nn
from torch.nn import functional as F
import math
import comfy.ops
class FourierFeatures(nn.Module):
def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
@@ -18,7 +19,7 @@ class FourierFeatures(nn.Module):
[out_features // 2, in_features], dtype=dtype, device=device))
def forward(self, input):
f = 2 * math.pi * input @ self.weight.T.to(dtype=input.dtype, device=input.device)
f = 2 * math.pi * input @ comfy.ops.cast_to_input(self.weight.T, input)
return torch.cat([f.cos(), f.sin()], dim=-1)
# norms
@@ -38,9 +39,9 @@ class LayerNorm(nn.Module):
def forward(self, x):
beta = self.beta
if self.beta is not None:
beta = beta.to(dtype=x.dtype, device=x.device)
return F.layer_norm(x, x.shape[-1:], weight=self.gamma.to(dtype=x.dtype, device=x.device), bias=beta)
if beta is not None:
beta = comfy.ops.cast_to_input(beta, x)
return F.layer_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x), bias=beta)
class GLU(nn.Module):
def __init__(
@@ -123,7 +124,9 @@ class RotaryEmbedding(nn.Module):
scale_base = 512,
interpolation_factor = 1.,
base = 10000,
base_rescale_factor = 1.
base_rescale_factor = 1.,
dtype=None,
device=None,
):
super().__init__()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
@@ -131,8 +134,8 @@ class RotaryEmbedding(nn.Module):
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
base *= base_rescale_factor ** (dim / (dim - 2))
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
# inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype))
assert interpolation_factor >= 1.
self.interpolation_factor = interpolation_factor
@@ -161,14 +164,14 @@ class RotaryEmbedding(nn.Module):
t = t / self.interpolation_factor
freqs = torch.einsum('i , j -> i j', t, self.inv_freq.to(dtype=dtype, device=device))
freqs = torch.einsum('i , j -> i j', t, comfy.ops.cast_to_input(self.inv_freq, t))
freqs = torch.cat((freqs, freqs), dim = -1)
if self.scale is None:
return freqs, 1.
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
scale = self.scale.to(dtype=dtype, device=device) ** rearrange(power, 'n -> n 1')
scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
scale = torch.cat((scale, scale), dim = -1)
return freqs, scale
@@ -568,7 +571,7 @@ class ContinuousTransformer(nn.Module):
self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()
if rotary_pos_emb:
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32), device=device, dtype=dtype)
else:
self.rotary_pos_emb = None