WIP support for Wan I2V model.

comfyanonymous committed 2025-02-26 01:49:43 -05:00
parent cb06e9669b
commit 4ced06b879
6 changed files with 116 additions and 17 deletions


@@ -10,6 +10,7 @@ from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.layers import EmbedND
 from comfy.ldm.flux.math import apply_rope
 import comfy.ldm.common_dit
+import comfy.model_management

 def sinusoidal_embedding_1d(dim, position):
     # preprocess
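For reference, the sinusoidal_embedding_1d shown in this hunk's context computes a standard 1D sinusoidal position embedding. A minimal sketch of that formulation (the file's exact variant may differ in dtype and half-dim handling):

import torch

def sinusoidal_embedding_1d(dim, position):
    # Standard sinusoidal embedding: pair each position with a
    # geometric ladder of frequencies, then take cos/sin halves.
    assert dim % 2 == 0
    half = dim // 2
    freqs = torch.pow(10000, -torch.arange(half, dtype=torch.float64) / half)
    args = torch.outer(position.to(torch.float64), freqs)
    return torch.cat([torch.cos(args), torch.sin(args)], dim=1)

emb = sinusoidal_embedding_1d(256, torch.arange(10))
print(emb.shape)  # torch.Size([10, 256])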
@@ -37,7 +38,7 @@ class WanRMSNorm(nn.Module):
         Args:
             x(Tensor): Shape [B, L, C]
         """
-        return self._norm(x.float()).type_as(x) * self.weight
+        return self._norm(x.float()).type_as(x) * comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device)

     def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
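The recurring change in this commit replaces direct parameter use with comfy.model_management.cast_to, so a weight stored in a different dtype or on a different device (e.g. low-precision or offloaded storage) is cast to match the activation at call time. A minimal sketch of what a cast-on-use helper along these lines does, not ComfyUI's exact implementation:

import torch

def cast_to(weight, dtype=None, device=None):
    # No-op when the parameter already matches the activation;
    # otherwise copy it to the requested dtype/device for this call.
    if (dtype is None or weight.dtype == dtype) and \
       (device is None or weight.device == device):
        return weight
    return weight.to(dtype=dtype, device=device)

w = torch.randn(8, dtype=torch.float32)
x = torch.randn(4, 8, dtype=torch.float16)
y = x * cast_to(w, dtype=x.dtype, device=x.device)  # math runs in fp16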
@@ -125,7 +126,7 @@ class WanI2VCrossAttention(WanSelfAttention):
                  window_size=(-1, -1),
                  qk_norm=True,
                  eps=1e-6, operation_settings={}):
-        super().__init__(dim, num_heads, window_size, qk_norm, eps)
+        super().__init__(dim, num_heads, window_size, qk_norm, eps, operation_settings=operation_settings)

         self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
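The super().__init__ fix matters because the parent class builds its own Linear layers from operation_settings; dropping the keyword would have constructed them with defaults instead of the caller's operations/device/dtype. A minimal sketch of the pattern under assumed names (Base and Child are illustrative, not the classes in the diff):

import torch.nn as nn

class Base(nn.Module):
    def __init__(self, dim, operation_settings={}):
        super().__init__()
        ops = operation_settings.get("operations", nn)
        # Built by whatever operations factory the caller supplied.
        self.q = ops.Linear(dim, dim, device=operation_settings.get("device"),
                            dtype=operation_settings.get("dtype"))

class Child(Base):
    def __init__(self, dim, operation_settings={}):
        # Forward the settings; otherwise Base silently falls back to
        # plain nn.Linear on the default device/dtype.
        super().__init__(dim, operation_settings=operation_settings)
        ops = operation_settings.get("operations", nn)
        self.k_img = ops.Linear(dim, dim, device=operation_settings.get("device"),
                                dtype=operation_settings.get("dtype"))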
@@ -218,7 +219,7 @@ class WanAttentionBlock(nn.Module):
         """
         # assert e.dtype == torch.float32
-        e = (self.modulation + e).chunk(6, dim=1)
+        e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
         # assert e[0].dtype == torch.float32

         # self-attention
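For context, the modulation tensor being cast here is an AdaLN-style table that gets split six ways into shift/scale/gate values for the attention and MLP paths. A toy example of the chunk (shapes and names are illustrative, inferred from the surrounding code):

import torch

B, dim = 2, 16
modulation = torch.zeros(1, 6, dim)   # learned parameter in the block
e = torch.randn(B, 6, dim)            # per-timestep embedding
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
    (modulation + e).chunk(6, dim=1)
print(shift_msa.shape)  # torch.Size([2, 1, 16])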
@@ -263,7 +264,7 @@ class Head(nn.Module):
             e(Tensor): Shape [B, C]
         """
         # assert e.dtype == torch.float32
-        e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
+        e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
         x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
         return x
@@ -401,7 +402,6 @@ class WanModel(torch.nn.Module):
         t,
         context,
         clip_fea=None,
-        y=None,
         freqs=None,
     ):
         r"""
@@ -425,12 +425,6 @@ class WanModel(torch.nn.Module):
             List[Tensor]:
                 List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
         """
-        if self.model_type == 'i2v':
-            assert clip_fea is not None and y is not None
-
-        if y is not None:
-            x = torch.cat([x, y], dim=0)
-
         # embeddings
         x = self.patch_embedding(x)
         grid_sizes = x.shape[2:]
@@ -465,7 +459,7 @@ class WanModel(torch.nn.Module):
         return x
         # return [u.float() for u in x]

-    def forward(self, x, timestep, context, y=None, image=None, **kwargs):
+    def forward(self, x, timestep, context, clip_fea=None, **kwargs):
         bs, c, t, h, w = x.shape
         x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
         patch_size = self.patch_size
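comfy.ldm.common_dit.pad_to_patch_size in the context above rounds each latent dimension up to a multiple of the patch size before patch embedding. An assumed re-implementation of that idea, not the library function itself:

import torch
import torch.nn.functional as F

def pad_to_patch_size(x, patch_size):
    # Build F.pad's (last-dim-first) pad list so every trailing dim
    # becomes a multiple of its patch size.
    pad = []
    for size, p in zip(x.shape[2:][::-1], patch_size[::-1]):
        pad += [0, (p - size % p) % p]
    return F.pad(x, pad)

x = torch.randn(1, 16, 5, 31, 31)
print(pad_to_patch_size(x, (1, 2, 2)).shape)  # torch.Size([1, 16, 5, 32, 32])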
@@ -479,7 +473,7 @@ class WanModel(torch.nn.Module):
         img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

         freqs = self.rope_embedder(img_ids).movedim(1, 2)
-        return self.forward_orig(x, timestep, context, clip_fea=y, y=image, freqs=freqs)[:, :, :t, :h, :w]
+        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w]

     def unpatchify(self, x, grid_sizes):
         r"""