Use common function for casting weights to input.

comfyanonymous
2024-07-30 05:03:20 -04:00
parent 79040635da
commit 25853d0be8
7 changed files with 51 additions and 31 deletions
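
The change swaps ad-hoc tensor.to(device=..., dtype=...) calls at each call site for a single shared helper, comfy.ops.cast_to_input, which casts a weight (or buffer) to the device and dtype of the activation it is about to be combined with. A minimal sketch of what such a helper plausibly looks like; the actual definition lives in comfy/ops.py and may carry extra arguments:

import torch

def cast_to_input(weight, input):
    # Move/cast weight so it matches input's device and dtype.
    # torch.Tensor.to returns weight unchanged if both already match.
    return weight.to(device=input.device, dtype=input.dtype)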


@@ -8,6 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from comfy.ldm.modules.attention import optimized_attention
+import comfy.ops


 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -427,7 +428,7 @@ class MMDiT(nn.Module):
         max_dim = max(h, w)

         cur_dim = self.h_max
-        pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype)
+        pos_encoding = comfy.ops.cast_to_input(self.positional_encoding.reshape(1, cur_dim, cur_dim, -1), x)

         if max_dim > cur_dim:
             pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
@@ -455,7 +456,7 @@ class MMDiT(nn.Module):
         t = timestep

         c = self.cond_seq_linear(c_seq)  # B, T_c, D
-        c = torch.cat([self.register_tokens.to(device=c.device, dtype=c.dtype).repeat(c.size(0), 1, 1), c], dim=1)
+        c = torch.cat([comfy.ops.cast_to_input(self.register_tokens, c).repeat(c.size(0), 1, 1), c], dim=1)

         global_cond = self.t_embedder(t, x.dtype)  # B, D
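
Both rewrites above are drop-in equivalents of the .to(...) calls they replace; only the spelling of the cast is centralized. A quick standalone check against the sketch above (the helper is restated here, hypothetically, so the snippet runs on its own):

import torch

def cast_to_input(weight, input):
    # Same sketch as above: match input's device and dtype.
    return weight.to(device=input.device, dtype=input.dtype)

# Mirrors the register_tokens call site: fp32 tokens joined to an fp16 sequence.
c = torch.randn(2, 3, 8, dtype=torch.float16)
register_tokens = torch.randn(1, 4, 8, dtype=torch.float32)
c = torch.cat([cast_to_input(register_tokens, c).repeat(c.size(0), 1, 1), c], dim=1)
assert c.dtype == torch.float16 and c.shape == (2, 7, 8)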