mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-12 12:37:01 +00:00
WIP support for Wan I2V model.
This commit is contained in:
@@ -10,6 +10,7 @@ from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_management
|
||||
|
||||
def sinusoidal_embedding_1d(dim, position):
|
||||
# preprocess
|
||||
@@ -37,7 +38,7 @@ class WanRMSNorm(nn.Module):
|
||||
Args:
|
||||
x(Tensor): Shape [B, L, C]
|
||||
"""
|
||||
return self._norm(x.float()).type_as(x) * self.weight
|
||||
return self._norm(x.float()).type_as(x) * comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device)
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
||||
@@ -125,7 +126,7 @@ class WanI2VCrossAttention(WanSelfAttention):
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
eps=1e-6, operation_settings={}):
|
||||
super().__init__(dim, num_heads, window_size, qk_norm, eps)
|
||||
super().__init__(dim, num_heads, window_size, qk_norm, eps, operation_settings=operation_settings)
|
||||
|
||||
self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
@@ -218,7 +219,7 @@ class WanAttentionBlock(nn.Module):
|
||||
"""
|
||||
# assert e.dtype == torch.float32
|
||||
|
||||
e = (self.modulation + e).chunk(6, dim=1)
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
|
||||
# assert e[0].dtype == torch.float32
|
||||
|
||||
# self-attention
|
||||
@@ -263,7 +264,7 @@ class Head(nn.Module):
|
||||
e(Tensor): Shape [B, C]
|
||||
"""
|
||||
# assert e.dtype == torch.float32
|
||||
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
|
||||
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
|
||||
return x
|
||||
|
||||
@@ -401,7 +402,6 @@ class WanModel(torch.nn.Module):
|
||||
t,
|
||||
context,
|
||||
clip_fea=None,
|
||||
y=None,
|
||||
freqs=None,
|
||||
):
|
||||
r"""
|
||||
@@ -425,12 +425,6 @@ class WanModel(torch.nn.Module):
|
||||
List[Tensor]:
|
||||
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
|
||||
"""
|
||||
if self.model_type == 'i2v':
|
||||
assert clip_fea is not None and y is not None
|
||||
|
||||
if y is not None:
|
||||
x = torch.cat([x, y], dim=0)
|
||||
|
||||
# embeddings
|
||||
x = self.patch_embedding(x)
|
||||
grid_sizes = x.shape[2:]
|
||||
@@ -465,7 +459,7 @@ class WanModel(torch.nn.Module):
|
||||
return x
|
||||
# return [u.float() for u in x]
|
||||
|
||||
def forward(self, x, timestep, context, y=None, image=None, **kwargs):
|
||||
def forward(self, x, timestep, context, clip_fea=None, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
patch_size = self.patch_size
|
||||
@@ -479,7 +473,7 @@ class WanModel(torch.nn.Module):
|
||||
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
|
||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=y, y=image, freqs=freqs)[:, :, :t, :h, :w]
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w]
|
||||
|
||||
def unpatchify(self, x, grid_sizes):
|
||||
r"""
|
||||
|
Reference in New Issue
Block a user