mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-10 11:35:40 +00:00
WIP Wan 2.2 S2V model. (#9568)
This commit is contained in:
@@ -4,7 +4,7 @@ import math
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from einops import repeat
|
from einops import rearrange
|
||||||
|
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
from comfy.ldm.flux.layers import EmbedND
|
from comfy.ldm.flux.layers import EmbedND
|
||||||
@@ -153,7 +153,10 @@ def repeat_e(e, x):
|
|||||||
repeats = x.size(1) // e.size(1)
|
repeats = x.size(1) // e.size(1)
|
||||||
if repeats == 1:
|
if repeats == 1:
|
||||||
return e
|
return e
|
||||||
return torch.repeat_interleave(e, repeats, dim=1)
|
if repeats * e.size(1) == x.size(1):
|
||||||
|
return torch.repeat_interleave(e, repeats, dim=1)
|
||||||
|
else:
|
||||||
|
return torch.repeat_interleave(e, repeats + 1, dim=1)[:, :x.size(1)]
|
||||||
|
|
||||||
|
|
||||||
class WanAttentionBlock(nn.Module):
|
class WanAttentionBlock(nn.Module):
|
||||||
@@ -573,6 +576,28 @@ class WanModel(torch.nn.Module):
|
|||||||
x = self.unpatchify(x, grid_sizes)
|
x = self.unpatchify(x, grid_sizes)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
|
||||||
|
patch_size = self.patch_size
|
||||||
|
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||||
|
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||||
|
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
||||||
|
|
||||||
|
if steps_t is None:
|
||||||
|
steps_t = t_len
|
||||||
|
if steps_h is None:
|
||||||
|
steps_h = h_len
|
||||||
|
if steps_w is None:
|
||||||
|
steps_w = w_len
|
||||||
|
|
||||||
|
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
||||||
|
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
|
||||||
|
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
||||||
|
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
||||||
|
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||||
|
|
||||||
|
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||||
|
return freqs
|
||||||
|
|
||||||
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
self._forward,
|
self._forward,
|
||||||
@@ -584,26 +609,16 @@ class WanModel(torch.nn.Module):
|
|||||||
bs, c, t, h, w = x.shape
|
bs, c, t, h, w = x.shape
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||||
|
|
||||||
patch_size = self.patch_size
|
t_len = t
|
||||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
|
||||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
|
||||||
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
|
||||||
|
|
||||||
if time_dim_concat is not None:
|
if time_dim_concat is not None:
|
||||||
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
|
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
|
||||||
x = torch.cat([x, time_dim_concat], dim=2)
|
x = torch.cat([x, time_dim_concat], dim=2)
|
||||||
t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
|
t_len = x.shape[2]
|
||||||
|
|
||||||
if self.ref_conv is not None and "reference_latent" in kwargs:
|
if self.ref_conv is not None and "reference_latent" in kwargs:
|
||||||
t_len += 1
|
t_len += 1
|
||||||
|
|
||||||
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
|
||||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
|
|
||||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
|
|
||||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
|
||||||
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
|
||||||
|
|
||||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
|
||||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
|
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
|
||||||
|
|
||||||
def unpatchify(self, x, grid_sizes):
|
def unpatchify(self, x, grid_sizes):
|
||||||
@@ -839,3 +854,466 @@ class CameraWanModel(WanModel):
|
|||||||
# unpatchify
|
# unpatchify
|
||||||
x = self.unpatchify(x, grid_sizes)
|
x = self.unpatchify(x, grid_sizes)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConv1d(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
chan_in,
|
||||||
|
chan_out,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
dilation=1,
|
||||||
|
pad_mode='replicate',
|
||||||
|
operations=None,
|
||||||
|
**kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.pad_mode = pad_mode
|
||||||
|
padding = (kernel_size - 1, 0) # T
|
||||||
|
self.time_causal_padding = padding
|
||||||
|
|
||||||
|
self.conv = operations.Conv1d(
|
||||||
|
chan_in,
|
||||||
|
chan_out,
|
||||||
|
kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
dilation=dilation,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = torch.nn.functional.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
||||||
|
return self.conv(x)
|
||||||
|
|
||||||
|
|
||||||
|
class MotionEncoder_tc(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_dim: int,
|
||||||
|
hidden_dim: int,
|
||||||
|
num_heads=int,
|
||||||
|
need_global=True,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,):
|
||||||
|
factory_kwargs = {"dtype": dtype, "device": device}
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.need_global = need_global
|
||||||
|
self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1, operations=operations, **factory_kwargs)
|
||||||
|
if need_global:
|
||||||
|
self.conv1_global = CausalConv1d(
|
||||||
|
in_dim, hidden_dim // 4, 3, stride=1, operations=operations, **factory_kwargs)
|
||||||
|
self.norm1 = operations.LayerNorm(
|
||||||
|
hidden_dim // 4,
|
||||||
|
elementwise_affine=False,
|
||||||
|
eps=1e-6,
|
||||||
|
**factory_kwargs)
|
||||||
|
self.act = nn.SiLU()
|
||||||
|
self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2, operations=operations, **factory_kwargs)
|
||||||
|
self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2, operations=operations, **factory_kwargs)
|
||||||
|
|
||||||
|
if need_global:
|
||||||
|
self.final_linear = operations.Linear(hidden_dim, hidden_dim, **factory_kwargs)
|
||||||
|
|
||||||
|
self.norm1 = operations.LayerNorm(
|
||||||
|
hidden_dim // 4,
|
||||||
|
elementwise_affine=False,
|
||||||
|
eps=1e-6,
|
||||||
|
**factory_kwargs)
|
||||||
|
|
||||||
|
self.norm2 = operations.LayerNorm(
|
||||||
|
hidden_dim // 2,
|
||||||
|
elementwise_affine=False,
|
||||||
|
eps=1e-6,
|
||||||
|
**factory_kwargs)
|
||||||
|
|
||||||
|
self.norm3 = operations.LayerNorm(
|
||||||
|
hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||||
|
|
||||||
|
self.padding_tokens = nn.Parameter(torch.empty(1, 1, 1, hidden_dim, **factory_kwargs))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = rearrange(x, 'b t c -> b c t')
|
||||||
|
x_ori = x.clone()
|
||||||
|
b, c, t = x.shape
|
||||||
|
x = self.conv1_local(x)
|
||||||
|
x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = rearrange(x, 'b t c -> b c t')
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = rearrange(x, 'b c t -> b t c')
|
||||||
|
x = self.norm2(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = rearrange(x, 'b t c -> b c t')
|
||||||
|
x = self.conv3(x)
|
||||||
|
x = rearrange(x, 'b c t -> b t c')
|
||||||
|
x = self.norm3(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = rearrange(x, '(b n) t c -> b t n c', b=b)
|
||||||
|
padding = comfy.model_management.cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1)
|
||||||
|
x = torch.cat([x, padding], dim=-2)
|
||||||
|
x_local = x.clone()
|
||||||
|
|
||||||
|
if not self.need_global:
|
||||||
|
return x_local
|
||||||
|
|
||||||
|
x = self.conv1_global(x_ori)
|
||||||
|
x = rearrange(x, 'b c t -> b t c')
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = rearrange(x, 'b t c -> b c t')
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = rearrange(x, 'b c t -> b t c')
|
||||||
|
x = self.norm2(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = rearrange(x, 'b t c -> b c t')
|
||||||
|
x = self.conv3(x)
|
||||||
|
x = rearrange(x, 'b c t -> b t c')
|
||||||
|
x = self.norm3(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.final_linear(x)
|
||||||
|
x = rearrange(x, '(b n) t c -> b t n c', b=b)
|
||||||
|
|
||||||
|
return x, x_local
|
||||||
|
|
||||||
|
|
||||||
|
class CausalAudioEncoder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
dim=5120,
|
||||||
|
num_layers=25,
|
||||||
|
out_dim=2048,
|
||||||
|
video_rate=8,
|
||||||
|
num_token=4,
|
||||||
|
need_global=False,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = MotionEncoder_tc(
|
||||||
|
in_dim=dim,
|
||||||
|
hidden_dim=out_dim,
|
||||||
|
num_heads=num_token,
|
||||||
|
need_global=need_global, dtype=dtype, device=device, operations=operations)
|
||||||
|
weight = torch.empty((1, num_layers, 1, 1), dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.weights = torch.nn.Parameter(weight)
|
||||||
|
self.act = torch.nn.SiLU()
|
||||||
|
|
||||||
|
def forward(self, features):
|
||||||
|
# features B * num_layers * dim * video_length
|
||||||
|
weights = self.act(comfy.model_management.cast_to(self.weights, dtype=features.dtype, device=features.device))
|
||||||
|
weights_sum = weights.sum(dim=1, keepdims=True)
|
||||||
|
weighted_feat = ((features * weights) / weights_sum).sum(
|
||||||
|
dim=1) # b dim f
|
||||||
|
weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim
|
||||||
|
res = self.encoder(weighted_feat) # b f n dim
|
||||||
|
return res # b f n dim
|
||||||
|
|
||||||
|
|
||||||
|
class AdaLayerNorm(nn.Module):
|
||||||
|
def __init__(self, embedding_dim, output_dim=None, norm_elementwise_affine=False, norm_eps=1e-5, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
output_dim = output_dim or embedding_dim * 2
|
||||||
|
|
||||||
|
self.silu = nn.SiLU()
|
||||||
|
self.linear = operations.Linear(embedding_dim, output_dim, dtype=dtype, device=device)
|
||||||
|
self.norm = operations.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, x, temb):
|
||||||
|
temb = self.linear(self.silu(temb))
|
||||||
|
shift, scale = temb.chunk(2, dim=1)
|
||||||
|
shift = shift[:, None, :]
|
||||||
|
scale = scale[:, None, :]
|
||||||
|
x = self.norm(x) * (1 + scale) + shift
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AudioInjector_WAN(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
dim=2048,
|
||||||
|
num_heads=32,
|
||||||
|
inject_layer=[0, 27],
|
||||||
|
root_net=None,
|
||||||
|
enable_adain=False,
|
||||||
|
adain_dim=2048,
|
||||||
|
adain_mode=None,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.enable_adain = enable_adain
|
||||||
|
self.adain_mode = adain_mode
|
||||||
|
self.injected_block_id = {}
|
||||||
|
audio_injector_id = 0
|
||||||
|
for inject_id in inject_layer:
|
||||||
|
self.injected_block_id[inject_id] = audio_injector_id
|
||||||
|
audio_injector_id += 1
|
||||||
|
|
||||||
|
self.injector = nn.ModuleList([
|
||||||
|
WanT2VCrossAttention(
|
||||||
|
dim=dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
qk_norm=True, operation_settings={"operations": operations, "device": device, "dtype": dtype}
|
||||||
|
) for _ in range(audio_injector_id)
|
||||||
|
])
|
||||||
|
self.injector_pre_norm_feat = nn.ModuleList([
|
||||||
|
operations.LayerNorm(
|
||||||
|
dim,
|
||||||
|
elementwise_affine=False,
|
||||||
|
eps=1e-6, dtype=dtype, device=device
|
||||||
|
) for _ in range(audio_injector_id)
|
||||||
|
])
|
||||||
|
self.injector_pre_norm_vec = nn.ModuleList([
|
||||||
|
operations.LayerNorm(
|
||||||
|
dim,
|
||||||
|
elementwise_affine=False,
|
||||||
|
eps=1e-6, dtype=dtype, device=device
|
||||||
|
) for _ in range(audio_injector_id)
|
||||||
|
])
|
||||||
|
if enable_adain:
|
||||||
|
self.injector_adain_layers = nn.ModuleList([
|
||||||
|
AdaLayerNorm(
|
||||||
|
output_dim=dim * 2, embedding_dim=adain_dim, dtype=dtype, device=device, operations=operations)
|
||||||
|
for _ in range(audio_injector_id)
|
||||||
|
])
|
||||||
|
if adain_mode != "attn_norm":
|
||||||
|
self.injector_adain_output_layers = nn.ModuleList(
|
||||||
|
[operations.Linear(dim, dim, dtype=dtype, device=device) for _ in range(audio_injector_id)])
|
||||||
|
|
||||||
|
def forward(self, x, block_id, audio_emb, audio_emb_global, seq_len):
|
||||||
|
audio_attn_id = self.injected_block_id.get(block_id, None)
|
||||||
|
if audio_attn_id is None:
|
||||||
|
return x
|
||||||
|
|
||||||
|
num_frames = audio_emb.shape[1]
|
||||||
|
input_hidden_states = rearrange(x[:, :seq_len], "b (t n) c -> (b t) n c", t=num_frames)
|
||||||
|
if self.enable_adain and self.adain_mode == "attn_norm":
|
||||||
|
audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c")
|
||||||
|
adain_hidden_states = self.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0])
|
||||||
|
attn_hidden_states = adain_hidden_states
|
||||||
|
else:
|
||||||
|
attn_hidden_states = self.injector_pre_norm_feat[audio_attn_id](input_hidden_states)
|
||||||
|
audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
|
||||||
|
attn_audio_emb = audio_emb
|
||||||
|
residual_out = self.injector[audio_attn_id](x=attn_hidden_states, context=attn_audio_emb)
|
||||||
|
residual_out = rearrange(
|
||||||
|
residual_out, "(b t) n c -> b (t n) c", t=num_frames)
|
||||||
|
x[:, :seq_len] = x[:, :seq_len] + residual_out
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class FramePackMotioner(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
inner_dim=1024,
|
||||||
|
num_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design
|
||||||
|
zip_frame_buckets=[
|
||||||
|
1, 2, 16
|
||||||
|
], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames
|
||||||
|
drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.proj = operations.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2), dtype=dtype, device=device)
|
||||||
|
self.proj_2x = operations.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4), dtype=dtype, device=device)
|
||||||
|
self.proj_4x = operations.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8), dtype=dtype, device=device)
|
||||||
|
self.zip_frame_buckets = zip_frame_buckets
|
||||||
|
|
||||||
|
self.inner_dim = inner_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
self.drop_mode = drop_mode
|
||||||
|
|
||||||
|
def forward(self, motion_latents, rope_embedder, add_last_motion=2):
|
||||||
|
lat_height, lat_width = motion_latents.shape[3], motion_latents.shape[4]
|
||||||
|
padd_lat = torch.zeros(motion_latents.shape[0], 16, sum(self.zip_frame_buckets), lat_height, lat_width).to(device=motion_latents.device, dtype=motion_latents.dtype)
|
||||||
|
overlap_frame = min(padd_lat.shape[2], motion_latents.shape[2])
|
||||||
|
if overlap_frame > 0:
|
||||||
|
padd_lat[:, :, -overlap_frame:] = motion_latents[:, :, -overlap_frame:]
|
||||||
|
|
||||||
|
if add_last_motion < 2 and self.drop_mode != "drop":
|
||||||
|
zero_end_frame = sum(self.zip_frame_buckets[:len(self.zip_frame_buckets) - add_last_motion - 1])
|
||||||
|
padd_lat[:, :, -zero_end_frame:] = 0
|
||||||
|
|
||||||
|
clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -sum(self.zip_frame_buckets):, :, :].split(self.zip_frame_buckets[::-1], dim=2) # 16, 2 ,1
|
||||||
|
|
||||||
|
# patchfy
|
||||||
|
clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2)
|
||||||
|
clean_latents_2x = self.proj_2x(clean_latents_2x)
|
||||||
|
l_2x_shape = clean_latents_2x.shape
|
||||||
|
clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
|
||||||
|
clean_latents_4x = self.proj_4x(clean_latents_4x)
|
||||||
|
l_4x_shape = clean_latents_4x.shape
|
||||||
|
clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
if add_last_motion < 2 and self.drop_mode == "drop":
|
||||||
|
clean_latents_post = clean_latents_post[:, :
|
||||||
|
0] if add_last_motion < 2 else clean_latents_post
|
||||||
|
clean_latents_2x = clean_latents_2x[:, :
|
||||||
|
0] if add_last_motion < 1 else clean_latents_2x
|
||||||
|
|
||||||
|
motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)
|
||||||
|
|
||||||
|
rope_post = rope_embedder.rope_encode(1, lat_height, lat_width, t_start=-1, device=motion_latents.device, dtype=motion_latents.dtype)
|
||||||
|
rope_2x = rope_embedder.rope_encode(1, lat_height, lat_width, t_start=-3, steps_h=l_2x_shape[-2], steps_w=l_2x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype)
|
||||||
|
rope_4x = rope_embedder.rope_encode(4, lat_height, lat_width, t_start=-19, steps_h=l_4x_shape[-2], steps_w=l_4x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype)
|
||||||
|
|
||||||
|
rope = torch.cat([rope_post, rope_2x, rope_4x], dim=1)
|
||||||
|
return motion_lat, rope
|
||||||
|
|
||||||
|
|
||||||
|
class WanModel_S2V(WanModel):
|
||||||
|
def __init__(self,
|
||||||
|
model_type='s2v',
|
||||||
|
patch_size=(1, 2, 2),
|
||||||
|
text_len=512,
|
||||||
|
in_dim=16,
|
||||||
|
dim=2048,
|
||||||
|
ffn_dim=8192,
|
||||||
|
freq_dim=256,
|
||||||
|
text_dim=4096,
|
||||||
|
out_dim=16,
|
||||||
|
num_heads=16,
|
||||||
|
num_layers=32,
|
||||||
|
window_size=(-1, -1),
|
||||||
|
qk_norm=True,
|
||||||
|
cross_attn_norm=True,
|
||||||
|
eps=1e-6,
|
||||||
|
audio_dim=1024,
|
||||||
|
num_audio_token=4,
|
||||||
|
enable_adain=True,
|
||||||
|
cond_dim=16,
|
||||||
|
audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39],
|
||||||
|
adain_mode="attn_norm",
|
||||||
|
framepack_drop_mode="padd",
|
||||||
|
image_model=None,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||||
|
|
||||||
|
self.trainable_cond_mask = operations.Embedding(3, self.dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.casual_audio_encoder = CausalAudioEncoder(
|
||||||
|
dim=audio_dim,
|
||||||
|
out_dim=self.dim,
|
||||||
|
num_token=num_audio_token,
|
||||||
|
need_global=enable_adain, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
if cond_dim > 0:
|
||||||
|
self.cond_encoder = operations.Conv3d(
|
||||||
|
cond_dim,
|
||||||
|
self.dim,
|
||||||
|
kernel_size=self.patch_size,
|
||||||
|
stride=self.patch_size, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.audio_injector = AudioInjector_WAN(
|
||||||
|
dim=self.dim,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
inject_layer=audio_inject_layers,
|
||||||
|
root_net=self,
|
||||||
|
enable_adain=enable_adain,
|
||||||
|
adain_dim=self.dim,
|
||||||
|
adain_mode=adain_mode,
|
||||||
|
dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
|
||||||
|
self.frame_packer = FramePackMotioner(
|
||||||
|
inner_dim=self.dim,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
zip_frame_buckets=[1, 2, 16],
|
||||||
|
drop_mode=framepack_drop_mode,
|
||||||
|
dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
def forward_orig(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
context,
|
||||||
|
audio_embed=None,
|
||||||
|
reference_latent=None,
|
||||||
|
control_video=None,
|
||||||
|
reference_motion=None,
|
||||||
|
clip_fea=None,
|
||||||
|
freqs=None,
|
||||||
|
transformer_options={},
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if audio_embed is not None:
|
||||||
|
num_embeds = x.shape[-3] * 4
|
||||||
|
audio_emb_global, audio_emb = self.casual_audio_encoder(audio_embed[:, :, :, :num_embeds])
|
||||||
|
else:
|
||||||
|
audio_emb = None
|
||||||
|
|
||||||
|
# embeddings
|
||||||
|
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||||
|
if control_video is not None:
|
||||||
|
x = x + self.cond_encoder(control_video)
|
||||||
|
|
||||||
|
if t.ndim == 1:
|
||||||
|
t = t.unsqueeze(1).repeat(1, x.shape[2])
|
||||||
|
|
||||||
|
grid_sizes = x.shape[2:]
|
||||||
|
x = x.flatten(2).transpose(1, 2)
|
||||||
|
seq_len = x.size(1)
|
||||||
|
|
||||||
|
cond_mask_weight = comfy.model_management.cast_to(self.trainable_cond_mask.weight, dtype=x.dtype, device=x.device).unsqueeze(1).unsqueeze(1)
|
||||||
|
x = x + cond_mask_weight[0]
|
||||||
|
|
||||||
|
if reference_latent is not None:
|
||||||
|
ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
|
||||||
|
ref = ref.flatten(2).transpose(1, 2)
|
||||||
|
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=30, device=x.device, dtype=x.dtype)
|
||||||
|
ref = ref + cond_mask_weight[1]
|
||||||
|
x = torch.cat([x, ref], dim=1)
|
||||||
|
freqs = torch.cat([freqs, freqs_ref], dim=1)
|
||||||
|
t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1)
|
||||||
|
|
||||||
|
if reference_motion is not None:
|
||||||
|
motion_encoded, freqs_motion = self.frame_packer(reference_motion, self)
|
||||||
|
motion_encoded = motion_encoded + cond_mask_weight[2]
|
||||||
|
x = torch.cat([x, motion_encoded], dim=1)
|
||||||
|
freqs = torch.cat([freqs, freqs_motion], dim=1)
|
||||||
|
|
||||||
|
t = torch.repeat_interleave(t, 2, dim=1)
|
||||||
|
t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1)
|
||||||
|
|
||||||
|
# time embeddings
|
||||||
|
e = self.time_embedding(
|
||||||
|
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
|
||||||
|
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
||||||
|
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||||
|
|
||||||
|
# context
|
||||||
|
context = self.text_embedding(context)
|
||||||
|
|
||||||
|
|
||||||
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
for i, block in enumerate(self.blocks):
|
||||||
|
if ("double_block", i) in blocks_replace:
|
||||||
|
def block_wrap(args):
|
||||||
|
out = {}
|
||||||
|
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
|
||||||
|
return out
|
||||||
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||||
|
x = out["img"]
|
||||||
|
else:
|
||||||
|
x = block(x, e=e0, freqs=freqs, context=context)
|
||||||
|
if audio_emb is not None:
|
||||||
|
x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
|
||||||
|
# head
|
||||||
|
x = self.head(x, e)
|
||||||
|
|
||||||
|
# unpatchify
|
||||||
|
x = self.unpatchify(x, grid_sizes)
|
||||||
|
return x
|
||||||
|
@@ -1201,6 +1201,29 @@ class WAN21_Camera(WAN21):
|
|||||||
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
|
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class WAN22_S2V(WAN21):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
|
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
|
||||||
|
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = super().extra_conds(**kwargs)
|
||||||
|
audio_embed = kwargs.get("audio_embed", None)
|
||||||
|
if audio_embed is not None:
|
||||||
|
out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
|
||||||
|
|
||||||
|
reference_latents = kwargs.get("reference_latents", None)
|
||||||
|
if reference_latents is not None:
|
||||||
|
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
|
||||||
|
|
||||||
|
reference_motion = kwargs.get("reference_motion", None)
|
||||||
|
if reference_motion is not None:
|
||||||
|
out['reference_motion'] = comfy.conds.CONDRegular(self.process_latent_in(reference_motion))
|
||||||
|
|
||||||
|
control_video = kwargs.get("control_video", None)
|
||||||
|
if control_video is not None:
|
||||||
|
out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video))
|
||||||
|
return out
|
||||||
|
|
||||||
class WAN22(BaseModel):
|
class WAN22(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||||
|
@@ -368,6 +368,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["model_type"] = "camera"
|
dit_config["model_type"] = "camera"
|
||||||
else:
|
else:
|
||||||
dit_config["model_type"] = "camera_2.2"
|
dit_config["model_type"] = "camera_2.2"
|
||||||
|
elif '{}casual_audio_encoder.encoder.final_linear.weight'.format(key_prefix) in state_dict_keys:
|
||||||
|
dit_config["model_type"] = "s2v"
|
||||||
else:
|
else:
|
||||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||||
dit_config["model_type"] = "i2v"
|
dit_config["model_type"] = "i2v"
|
||||||
|
@@ -1072,6 +1072,19 @@ class WAN21_Vace(WAN21_T2V):
|
|||||||
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
|
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class WAN22_S2V(WAN21_T2V):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "wan2.1",
|
||||||
|
"model_type": "s2v",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, unet_config):
|
||||||
|
super().__init__(unet_config)
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.WAN22_S2V(self, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
class WAN22_T2V(WAN21_T2V):
|
class WAN22_T2V(WAN21_T2V):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "wan2.1",
|
"image_model": "wan2.1",
|
||||||
@@ -1272,6 +1285,6 @@ class QwenImage(supported_models_base.BASE):
|
|||||||
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
|
||||||
|
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
@@ -786,6 +786,180 @@ class WanTrackToVideo(io.ComfyNode):
|
|||||||
return io.NodeOutput(positive, negative, out_latent)
|
return io.NodeOutput(positive, negative, out_latent)
|
||||||
|
|
||||||
|
|
||||||
|
def linear_interpolation(features, input_fps, output_fps, output_len=None):
|
||||||
|
"""
|
||||||
|
features: shape=[1, T, 512]
|
||||||
|
input_fps: fps for audio, f_a
|
||||||
|
output_fps: fps for video, f_m
|
||||||
|
output_len: video length
|
||||||
|
"""
|
||||||
|
features = features.transpose(1, 2) # [1, 512, T]
|
||||||
|
seq_len = features.shape[2] / float(input_fps) # T/f_a
|
||||||
|
if output_len is None:
|
||||||
|
output_len = int(seq_len * output_fps) # f_m*T/f_a
|
||||||
|
output_features = torch.nn.functional.interpolate(
|
||||||
|
features, size=output_len, align_corners=True,
|
||||||
|
mode='linear') # [1, 512, output_len]
|
||||||
|
return output_features.transpose(1, 2) # [1, output_len, 512]
|
||||||
|
|
||||||
|
|
||||||
|
def get_sample_indices(original_fps,
|
||||||
|
total_frames,
|
||||||
|
target_fps,
|
||||||
|
num_sample,
|
||||||
|
fixed_start=None):
|
||||||
|
required_duration = num_sample / target_fps
|
||||||
|
required_origin_frames = int(np.ceil(required_duration * original_fps))
|
||||||
|
if required_duration > total_frames / original_fps:
|
||||||
|
raise ValueError("required_duration must be less than video length")
|
||||||
|
|
||||||
|
if not fixed_start is None and fixed_start >= 0:
|
||||||
|
start_frame = fixed_start
|
||||||
|
else:
|
||||||
|
max_start = total_frames - required_origin_frames
|
||||||
|
if max_start < 0:
|
||||||
|
raise ValueError("video length is too short")
|
||||||
|
start_frame = np.random.randint(0, max_start + 1)
|
||||||
|
start_time = start_frame / original_fps
|
||||||
|
|
||||||
|
end_time = start_time + required_duration
|
||||||
|
time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)
|
||||||
|
|
||||||
|
frame_indices = np.round(np.array(time_points) * original_fps).astype(int)
|
||||||
|
frame_indices = np.clip(frame_indices, 0, total_frames - 1)
|
||||||
|
return frame_indices
|
||||||
|
|
||||||
|
|
||||||
|
def get_audio_embed_bucket_fps(audio_embed, fps=16, batch_frames=81, m=0, video_rate=30):
|
||||||
|
num_layers, audio_frame_num, audio_dim = audio_embed.shape
|
||||||
|
|
||||||
|
if num_layers > 1:
|
||||||
|
return_all_layers = True
|
||||||
|
else:
|
||||||
|
return_all_layers = False
|
||||||
|
|
||||||
|
scale = video_rate / fps
|
||||||
|
|
||||||
|
min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1
|
||||||
|
|
||||||
|
bucket_num = min_batch_num * batch_frames
|
||||||
|
padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * video_rate) - audio_frame_num
|
||||||
|
batch_idx = get_sample_indices(
|
||||||
|
original_fps=video_rate,
|
||||||
|
total_frames=audio_frame_num + padd_audio_num,
|
||||||
|
target_fps=fps,
|
||||||
|
num_sample=bucket_num,
|
||||||
|
fixed_start=0)
|
||||||
|
batch_audio_eb = []
|
||||||
|
audio_sample_stride = int(video_rate / fps)
|
||||||
|
for bi in batch_idx:
|
||||||
|
if bi < audio_frame_num:
|
||||||
|
|
||||||
|
chosen_idx = list(
|
||||||
|
range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride))
|
||||||
|
chosen_idx = [0 if c < 0 else c for c in chosen_idx]
|
||||||
|
chosen_idx = [
|
||||||
|
audio_frame_num - 1 if c >= audio_frame_num else c
|
||||||
|
for c in chosen_idx
|
||||||
|
]
|
||||||
|
|
||||||
|
if return_all_layers:
|
||||||
|
frame_audio_embed = audio_embed[:, chosen_idx].flatten(
|
||||||
|
start_dim=-2, end_dim=-1)
|
||||||
|
else:
|
||||||
|
frame_audio_embed = audio_embed[0][chosen_idx].flatten()
|
||||||
|
else:
|
||||||
|
frame_audio_embed = torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
|
||||||
|
else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
|
||||||
|
batch_audio_eb.append(frame_audio_embed)
|
||||||
|
batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0)
|
||||||
|
|
||||||
|
return batch_audio_eb, min_batch_num
|
||||||
|
|
||||||
|
|
||||||
|
class WanSoundImageToVideo(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="WanSoundImageToVideo",
|
||||||
|
category="conditioning/video_models",
|
||||||
|
inputs=[
|
||||||
|
io.Conditioning.Input("positive"),
|
||||||
|
io.Conditioning.Input("negative"),
|
||||||
|
io.Vae.Input("vae"),
|
||||||
|
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||||
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
io.AudioEncoderOutput.Input("audio_encoder_output", optional=True),
|
||||||
|
io.Image.Input("ref_image", optional=True),
|
||||||
|
io.Image.Input("control_video", optional=True),
|
||||||
|
io.Image.Input("ref_motion", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
|
io.Conditioning.Output(display_name="negative"),
|
||||||
|
io.Latent.Output(display_name="latent"),
|
||||||
|
],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None) -> io.NodeOutput:
|
||||||
|
latent_t = ((length - 1) // 4) + 1
|
||||||
|
if audio_encoder_output is not None:
|
||||||
|
feat = torch.cat(audio_encoder_output["encoded_audio_all_layers"])
|
||||||
|
video_rate = 30
|
||||||
|
fps = 16
|
||||||
|
feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate)
|
||||||
|
audio_embed_bucket, num_repeat = get_audio_embed_bucket_fps(feat, fps=fps, batch_frames=latent_t * 4, m=0, video_rate=video_rate)
|
||||||
|
audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
|
||||||
|
if len(audio_embed_bucket.shape) == 3:
|
||||||
|
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
|
||||||
|
elif len(audio_embed_bucket.shape) == 4:
|
||||||
|
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
|
||||||
|
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket})
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket})
|
||||||
|
|
||||||
|
if ref_image is not None:
|
||||||
|
ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
ref_latent = vae.encode(ref_image[:, :, :, :3])
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
|
||||||
|
|
||||||
|
if ref_motion is not None:
|
||||||
|
if ref_motion.shape[0] > 73:
|
||||||
|
ref_motion = ref_motion[-73:]
|
||||||
|
|
||||||
|
ref_motion = comfy.utils.common_upscale(ref_motion.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
|
||||||
|
if ref_motion.shape[0] < 73:
|
||||||
|
r = torch.ones([73, height, width, 3]) * 0.5
|
||||||
|
r[-ref_motion.shape[0]:] = ref_motion
|
||||||
|
ref_motion = r
|
||||||
|
|
||||||
|
ref_motion = vae.encode(ref_motion[:, :, :, :3])
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"reference_motion": ref_motion})
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"reference_motion": ref_motion})
|
||||||
|
|
||||||
|
latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
|
control_video_out = comfy.latent_formats.Wan21().process_out(torch.zeros_like(latent))
|
||||||
|
if control_video is not None:
|
||||||
|
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
control_video = vae.encode(control_video[:, :, :, :3])
|
||||||
|
control_video_out[:, :, :control_video.shape[2]] = control_video
|
||||||
|
|
||||||
|
# TODO: check if zero is better than none if none provided
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"control_video": control_video_out})
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"control_video": control_video_out})
|
||||||
|
|
||||||
|
out_latent = {}
|
||||||
|
out_latent["samples"] = latent
|
||||||
|
return io.NodeOutput(positive, negative, out_latent)
|
||||||
|
|
||||||
|
|
||||||
class Wan22ImageToVideoLatent(io.ComfyNode):
|
class Wan22ImageToVideoLatent(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@@ -844,6 +1018,7 @@ class WanExtension(ComfyExtension):
|
|||||||
TrimVideoLatent,
|
TrimVideoLatent,
|
||||||
WanCameraImageToVideo,
|
WanCameraImageToVideo,
|
||||||
WanPhantomSubjectToVideo,
|
WanPhantomSubjectToVideo,
|
||||||
|
WanSoundImageToVideo,
|
||||||
Wan22ImageToVideoLatent,
|
Wan22ImageToVideoLatent,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user