Use common function for casting weights to input.

This commit is contained in:
comfyanonymous
2024-07-30 05:03:20 -04:00
parent 79040635da
commit 25853d0be8
7 changed files with 51 additions and 31 deletions

View File

@@ -4,6 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from torch.utils import checkpoint
@@ -234,7 +235,7 @@ class HunYuanDiT(nn.Module):
if self.use_style_cond:
# Here we use a default learned embedder layer for future extension.
self.style_embedder = nn.Embedding(1, hidden_size, dtype=dtype, device=device)
self.style_embedder = operations.Embedding(1, hidden_size, dtype=dtype, device=device)
self.extra_in_dim += hidden_size
# Text embedding for `add`
@@ -321,7 +322,7 @@ class HunYuanDiT(nn.Module):
b_t5, l_t5, c_t5 = text_states_t5.shape
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
padding = self.text_embedding_padding.to(text_states)
padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)
text_states[:,-self.text_len:] = torch.where(text_states_mask[:,-self.text_len:].unsqueeze(2), text_states[:,-self.text_len:], padding[:self.text_len])
text_states_t5[:,-self.text_len_t5:] = torch.where(text_states_t5_mask[:,-self.text_len_t5:].unsqueeze(2), text_states_t5[:,-self.text_len_t5:], padding[self.text_len:])