diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 1885d9730..dedfb47e2 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -4,7 +4,7 @@ import math import torch import torch.nn as nn -from einops import repeat +from einops import rearrange from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.flux.layers import EmbedND @@ -153,7 +153,10 @@ def repeat_e(e, x): repeats = x.size(1) // e.size(1) if repeats == 1: return e - return torch.repeat_interleave(e, repeats, dim=1) + if repeats * e.size(1) == x.size(1): + return torch.repeat_interleave(e, repeats, dim=1) + else: + return torch.repeat_interleave(e, repeats + 1, dim=1)[:, :x.size(1)] class WanAttentionBlock(nn.Module): @@ -573,6 +576,28 @@ class WanModel(torch.nn.Module): x = self.unpatchify(x, grid_sizes) return x + def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None): + patch_size = self.patch_size + t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) + h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) + w_len = ((w + (patch_size[2] // 2)) // patch_size[2]) + + if steps_t is None: + steps_t = t_len + if steps_h is None: + steps_h = h_len + if steps_w is None: + steps_w = w_len + + img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype) + img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1) + img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1) + img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1) + img_ids = img_ids.reshape(1, -1, img_ids.shape[-1]) + + freqs = self.rope_embedder(img_ids).movedim(1, 2) + return freqs + def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, @@ -584,26 +609,16 @@ class WanModel(torch.nn.Module): bs, c, t, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) - patch_size = self.patch_size - t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) - h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) - w_len = ((w + (patch_size[2] // 2)) // patch_size[2]) - + t_len = t if time_dim_concat is not None: time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size) x = torch.cat([x, time_dim_concat], dim=2) - t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0]) + t_len = x.shape[2] if self.ref_conv is not None and "reference_latent" in kwargs: t_len += 1 - img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype) - img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1) - img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1) - img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1) - img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs) - - freqs = self.rope_embedder(img_ids).movedim(1, 2) + freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype) return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w] def unpatchify(self, x, grid_sizes): @@ -839,3 +854,466 @@ class CameraWanModel(WanModel): # unpatchify x = self.unpatchify(x, grid_sizes) return x + + +class CausalConv1d(nn.Module): + + def __init__(self, + chan_in, + chan_out, + kernel_size=3, + stride=1, + dilation=1, + pad_mode='replicate', + operations=None, + **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = operations.Conv1d( + chan_in, + chan_out, + kernel_size, + stride=stride, + dilation=dilation, + **kwargs) + + def forward(self, x): + x = torch.nn.functional.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + +class MotionEncoder_tc(nn.Module): + + def __init__(self, + in_dim: int, + hidden_dim: int, + num_heads=int, + need_global=True, + dtype=None, + device=None, + operations=None,): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.need_global = need_global + self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1, operations=operations, **factory_kwargs) + if need_global: + self.conv1_global = CausalConv1d( + in_dim, hidden_dim // 4, 3, stride=1, operations=operations, **factory_kwargs) + self.norm1 = operations.LayerNorm( + hidden_dim // 4, + elementwise_affine=False, + eps=1e-6, + **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2, operations=operations, **factory_kwargs) + self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2, operations=operations, **factory_kwargs) + + if need_global: + self.final_linear = operations.Linear(hidden_dim, hidden_dim, **factory_kwargs) + + self.norm1 = operations.LayerNorm( + hidden_dim // 4, + elementwise_affine=False, + eps=1e-6, + **factory_kwargs) + + self.norm2 = operations.LayerNorm( + hidden_dim // 2, + elementwise_affine=False, + eps=1e-6, + **factory_kwargs) + + self.norm3 = operations.LayerNorm( + hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.padding_tokens = nn.Parameter(torch.empty(1, 1, 1, hidden_dim, **factory_kwargs)) + + def forward(self, x): + x = rearrange(x, 'b t c -> b c t') + x_ori = x.clone() + b, c, t = x.shape + x = self.conv1_local(x) + x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads) + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv2(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv3(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm3(x) + x = self.act(x) + x = rearrange(x, '(b n) t c -> b t n c', b=b) + padding = comfy.model_management.cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + if not self.need_global: + return x_local + + x = self.conv1_global(x_ori) + x = rearrange(x, 'b c t -> b t c') + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv2(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv3(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm3(x) + x = self.act(x) + x = self.final_linear(x) + x = rearrange(x, '(b n) t c -> b t n c', b=b) + + return x, x_local + + +class CausalAudioEncoder(nn.Module): + + def __init__(self, + dim=5120, + num_layers=25, + out_dim=2048, + video_rate=8, + num_token=4, + need_global=False, + dtype=None, + device=None, + operations=None): + super().__init__() + self.encoder = MotionEncoder_tc( + in_dim=dim, + hidden_dim=out_dim, + num_heads=num_token, + need_global=need_global, dtype=dtype, device=device, operations=operations) + weight = torch.empty((1, num_layers, 1, 1), dtype=dtype, device=device) + + self.weights = torch.nn.Parameter(weight) + self.act = torch.nn.SiLU() + + def forward(self, features): + # features B * num_layers * dim * video_length + weights = self.act(comfy.model_management.cast_to(self.weights, dtype=features.dtype, device=features.device)) + weights_sum = weights.sum(dim=1, keepdims=True) + weighted_feat = ((features * weights) / weights_sum).sum( + dim=1) # b dim f + weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim + res = self.encoder(weighted_feat) # b f n dim + return res # b f n dim + + +class AdaLayerNorm(nn.Module): + def __init__(self, embedding_dim, output_dim=None, norm_elementwise_affine=False, norm_eps=1e-5, dtype=None, device=None, operations=None): + super().__init__() + + output_dim = output_dim or embedding_dim * 2 + + self.silu = nn.SiLU() + self.linear = operations.Linear(embedding_dim, output_dim, dtype=dtype, device=device) + self.norm = operations.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine, dtype=dtype, device=device) + + def forward(self, x, temb): + temb = self.linear(self.silu(temb)) + shift, scale = temb.chunk(2, dim=1) + shift = shift[:, None, :] + scale = scale[:, None, :] + x = self.norm(x) * (1 + scale) + shift + return x + + +class AudioInjector_WAN(nn.Module): + + def __init__(self, + dim=2048, + num_heads=32, + inject_layer=[0, 27], + root_net=None, + enable_adain=False, + adain_dim=2048, + adain_mode=None, + dtype=None, + device=None, + operations=None): + super().__init__() + self.enable_adain = enable_adain + self.adain_mode = adain_mode + self.injected_block_id = {} + audio_injector_id = 0 + for inject_id in inject_layer: + self.injected_block_id[inject_id] = audio_injector_id + audio_injector_id += 1 + + self.injector = nn.ModuleList([ + WanT2VCrossAttention( + dim=dim, + num_heads=num_heads, + qk_norm=True, operation_settings={"operations": operations, "device": device, "dtype": dtype} + ) for _ in range(audio_injector_id) + ]) + self.injector_pre_norm_feat = nn.ModuleList([ + operations.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, dtype=dtype, device=device + ) for _ in range(audio_injector_id) + ]) + self.injector_pre_norm_vec = nn.ModuleList([ + operations.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, dtype=dtype, device=device + ) for _ in range(audio_injector_id) + ]) + if enable_adain: + self.injector_adain_layers = nn.ModuleList([ + AdaLayerNorm( + output_dim=dim * 2, embedding_dim=adain_dim, dtype=dtype, device=device, operations=operations) + for _ in range(audio_injector_id) + ]) + if adain_mode != "attn_norm": + self.injector_adain_output_layers = nn.ModuleList( + [operations.Linear(dim, dim, dtype=dtype, device=device) for _ in range(audio_injector_id)]) + + def forward(self, x, block_id, audio_emb, audio_emb_global, seq_len): + audio_attn_id = self.injected_block_id.get(block_id, None) + if audio_attn_id is None: + return x + + num_frames = audio_emb.shape[1] + input_hidden_states = rearrange(x[:, :seq_len], "b (t n) c -> (b t) n c", t=num_frames) + if self.enable_adain and self.adain_mode == "attn_norm": + audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c") + adain_hidden_states = self.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0]) + attn_hidden_states = adain_hidden_states + else: + attn_hidden_states = self.injector_pre_norm_feat[audio_attn_id](input_hidden_states) + audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames) + attn_audio_emb = audio_emb + residual_out = self.injector[audio_attn_id](x=attn_hidden_states, context=attn_audio_emb) + residual_out = rearrange( + residual_out, "(b t) n c -> b (t n) c", t=num_frames) + x[:, :seq_len] = x[:, :seq_len] + residual_out + return x + + +class FramePackMotioner(nn.Module): + def __init__( + self, + inner_dim=1024, + num_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design + zip_frame_buckets=[ + 1, 2, 16 + ], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames + drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion + dtype=None, + device=None, + operations=None): + super().__init__() + self.proj = operations.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2), dtype=dtype, device=device) + self.proj_2x = operations.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4), dtype=dtype, device=device) + self.proj_4x = operations.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8), dtype=dtype, device=device) + self.zip_frame_buckets = zip_frame_buckets + + self.inner_dim = inner_dim + self.num_heads = num_heads + + self.drop_mode = drop_mode + + def forward(self, motion_latents, rope_embedder, add_last_motion=2): + lat_height, lat_width = motion_latents.shape[3], motion_latents.shape[4] + padd_lat = torch.zeros(motion_latents.shape[0], 16, sum(self.zip_frame_buckets), lat_height, lat_width).to(device=motion_latents.device, dtype=motion_latents.dtype) + overlap_frame = min(padd_lat.shape[2], motion_latents.shape[2]) + if overlap_frame > 0: + padd_lat[:, :, -overlap_frame:] = motion_latents[:, :, -overlap_frame:] + + if add_last_motion < 2 and self.drop_mode != "drop": + zero_end_frame = sum(self.zip_frame_buckets[:len(self.zip_frame_buckets) - add_last_motion - 1]) + padd_lat[:, :, -zero_end_frame:] = 0 + + clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -sum(self.zip_frame_buckets):, :, :].split(self.zip_frame_buckets[::-1], dim=2) # 16, 2 ,1 + + # patchfy + clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2) + clean_latents_2x = self.proj_2x(clean_latents_2x) + l_2x_shape = clean_latents_2x.shape + clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2) + clean_latents_4x = self.proj_4x(clean_latents_4x) + l_4x_shape = clean_latents_4x.shape + clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2) + + if add_last_motion < 2 and self.drop_mode == "drop": + clean_latents_post = clean_latents_post[:, : + 0] if add_last_motion < 2 else clean_latents_post + clean_latents_2x = clean_latents_2x[:, : + 0] if add_last_motion < 1 else clean_latents_2x + + motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1) + + rope_post = rope_embedder.rope_encode(1, lat_height, lat_width, t_start=-1, device=motion_latents.device, dtype=motion_latents.dtype) + rope_2x = rope_embedder.rope_encode(1, lat_height, lat_width, t_start=-3, steps_h=l_2x_shape[-2], steps_w=l_2x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype) + rope_4x = rope_embedder.rope_encode(4, lat_height, lat_width, t_start=-19, steps_h=l_4x_shape[-2], steps_w=l_4x_shape[-1], device=motion_latents.device, dtype=motion_latents.dtype) + + rope = torch.cat([rope_post, rope_2x, rope_4x], dim=1) + return motion_lat, rope + + +class WanModel_S2V(WanModel): + def __init__(self, + model_type='s2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + audio_dim=1024, + num_audio_token=4, + enable_adain=True, + cond_dim=16, + audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39], + adain_mode="attn_norm", + framepack_drop_mode="padd", + image_model=None, + device=None, + dtype=None, + operations=None, + ): + + super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, device=device, dtype=dtype, operations=operations) + + self.trainable_cond_mask = operations.Embedding(3, self.dim, device=device, dtype=dtype) + + self.casual_audio_encoder = CausalAudioEncoder( + dim=audio_dim, + out_dim=self.dim, + num_token=num_audio_token, + need_global=enable_adain, dtype=dtype, device=device, operations=operations) + + if cond_dim > 0: + self.cond_encoder = operations.Conv3d( + cond_dim, + self.dim, + kernel_size=self.patch_size, + stride=self.patch_size, device=device, dtype=dtype) + + self.audio_injector = AudioInjector_WAN( + dim=self.dim, + num_heads=self.num_heads, + inject_layer=audio_inject_layers, + root_net=self, + enable_adain=enable_adain, + adain_dim=self.dim, + adain_mode=adain_mode, + dtype=dtype, device=device, operations=operations + ) + + self.frame_packer = FramePackMotioner( + inner_dim=self.dim, + num_heads=self.num_heads, + zip_frame_buckets=[1, 2, 16], + drop_mode=framepack_drop_mode, + dtype=dtype, device=device, operations=operations) + + def forward_orig( + self, + x, + t, + context, + audio_embed=None, + reference_latent=None, + control_video=None, + reference_motion=None, + clip_fea=None, + freqs=None, + transformer_options={}, + **kwargs, + ): + if audio_embed is not None: + num_embeds = x.shape[-3] * 4 + audio_emb_global, audio_emb = self.casual_audio_encoder(audio_embed[:, :, :, :num_embeds]) + else: + audio_emb = None + + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + if control_video is not None: + x = x + self.cond_encoder(control_video) + + if t.ndim == 1: + t = t.unsqueeze(1).repeat(1, x.shape[2]) + + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + seq_len = x.size(1) + + cond_mask_weight = comfy.model_management.cast_to(self.trainable_cond_mask.weight, dtype=x.dtype, device=x.device).unsqueeze(1).unsqueeze(1) + x = x + cond_mask_weight[0] + + if reference_latent is not None: + ref = self.patch_embedding(reference_latent.float()).to(x.dtype) + ref = ref.flatten(2).transpose(1, 2) + freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=30, device=x.device, dtype=x.dtype) + ref = ref + cond_mask_weight[1] + x = torch.cat([x, ref], dim=1) + freqs = torch.cat([freqs, freqs_ref], dim=1) + t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1) + + if reference_motion is not None: + motion_encoded, freqs_motion = self.frame_packer(reference_motion, self) + motion_encoded = motion_encoded + cond_mask_weight[2] + x = torch.cat([x, motion_encoded], dim=1) + freqs = torch.cat([freqs, freqs_motion], dim=1) + + t = torch.repeat_interleave(t, 2, dim=1) + t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1) + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype)) + e = e.reshape(t.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + # context + context = self.text_embedding(context) + + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context) + if audio_emb is not None: + x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len) + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return x diff --git a/comfy/model_base.py b/comfy/model_base.py index 6c861b15e..18d55c1c4 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1201,6 +1201,29 @@ class WAN21_Camera(WAN21): out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions) return out +class WAN22_S2V(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + audio_embed = kwargs.get("audio_embed", None) + if audio_embed is not None: + out['audio_embed'] = comfy.conds.CONDRegular(audio_embed) + + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])) + + reference_motion = kwargs.get("reference_motion", None) + if reference_motion is not None: + out['reference_motion'] = comfy.conds.CONDRegular(self.process_latent_in(reference_motion)) + + control_video = kwargs.get("control_video", None) + if control_video is not None: + out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video)) + return out + class WAN22(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 0caff53e0..9f3ab64df 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -368,6 +368,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "camera" else: dit_config["model_type"] = "camera_2.2" + elif '{}casual_audio_encoder.encoder.final_linear.weight'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "s2v" else: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "i2v" diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 7ed6dfd69..ce571e6cb 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1072,6 +1072,19 @@ class WAN21_Vace(WAN21_T2V): out = model_base.WAN21_Vace(self, image_to_video=False, device=device) return out +class WAN22_S2V(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "s2v", + } + + def __init__(self, unet_config): + super().__init__(unet_config) + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN22_S2V(self, device=device) + return out + class WAN22_T2V(WAN21_T2V): unet_config = { "image_model": "wan2.1", @@ -1272,6 +1285,6 @@ class QwenImage(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 0fff02f76..89ff74d85 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -786,6 +786,180 @@ class WanTrackToVideo(io.ComfyNode): return io.NodeOutput(positive, negative, out_latent) +def linear_interpolation(features, input_fps, output_fps, output_len=None): + """ + features: shape=[1, T, 512] + input_fps: fps for audio, f_a + output_fps: fps for video, f_m + output_len: video length + """ + features = features.transpose(1, 2) # [1, 512, T] + seq_len = features.shape[2] / float(input_fps) # T/f_a + if output_len is None: + output_len = int(seq_len * output_fps) # f_m*T/f_a + output_features = torch.nn.functional.interpolate( + features, size=output_len, align_corners=True, + mode='linear') # [1, 512, output_len] + return output_features.transpose(1, 2) # [1, output_len, 512] + + +def get_sample_indices(original_fps, + total_frames, + target_fps, + num_sample, + fixed_start=None): + required_duration = num_sample / target_fps + required_origin_frames = int(np.ceil(required_duration * original_fps)) + if required_duration > total_frames / original_fps: + raise ValueError("required_duration must be less than video length") + + if not fixed_start is None and fixed_start >= 0: + start_frame = fixed_start + else: + max_start = total_frames - required_origin_frames + if max_start < 0: + raise ValueError("video length is too short") + start_frame = np.random.randint(0, max_start + 1) + start_time = start_frame / original_fps + + end_time = start_time + required_duration + time_points = np.linspace(start_time, end_time, num_sample, endpoint=False) + + frame_indices = np.round(np.array(time_points) * original_fps).astype(int) + frame_indices = np.clip(frame_indices, 0, total_frames - 1) + return frame_indices + + +def get_audio_embed_bucket_fps(audio_embed, fps=16, batch_frames=81, m=0, video_rate=30): + num_layers, audio_frame_num, audio_dim = audio_embed.shape + + if num_layers > 1: + return_all_layers = True + else: + return_all_layers = False + + scale = video_rate / fps + + min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1 + + bucket_num = min_batch_num * batch_frames + padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * video_rate) - audio_frame_num + batch_idx = get_sample_indices( + original_fps=video_rate, + total_frames=audio_frame_num + padd_audio_num, + target_fps=fps, + num_sample=bucket_num, + fixed_start=0) + batch_audio_eb = [] + audio_sample_stride = int(video_rate / fps) + for bi in batch_idx: + if bi < audio_frame_num: + + chosen_idx = list( + range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride)) + chosen_idx = [0 if c < 0 else c for c in chosen_idx] + chosen_idx = [ + audio_frame_num - 1 if c >= audio_frame_num else c + for c in chosen_idx + ] + + if return_all_layers: + frame_audio_embed = audio_embed[:, chosen_idx].flatten( + start_dim=-2, end_dim=-1) + else: + frame_audio_embed = audio_embed[0][chosen_idx].flatten() + else: + frame_audio_embed = torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \ + else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device) + batch_audio_eb.append(frame_audio_embed) + batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) + + return batch_audio_eb, min_batch_num + + +class WanSoundImageToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanSoundImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.AudioEncoderOutput.Input("audio_encoder_output", optional=True), + io.Image.Input("ref_image", optional=True), + io.Image.Input("control_video", optional=True), + io.Image.Input("ref_motion", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None) -> io.NodeOutput: + latent_t = ((length - 1) // 4) + 1 + if audio_encoder_output is not None: + feat = torch.cat(audio_encoder_output["encoded_audio_all_layers"]) + video_rate = 30 + fps = 16 + feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate) + audio_embed_bucket, num_repeat = get_audio_embed_bucket_fps(feat, fps=fps, batch_frames=latent_t * 4, m=0, video_rate=video_rate) + audio_embed_bucket = audio_embed_bucket.unsqueeze(0) + if len(audio_embed_bucket.shape) == 3: + audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1) + elif len(audio_embed_bucket.shape) == 4: + audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1) + + positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket}) + negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket}) + + if ref_image is not None: + ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + ref_latent = vae.encode(ref_image[:, :, :, :3]) + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True) + + if ref_motion is not None: + if ref_motion.shape[0] > 73: + ref_motion = ref_motion[-73:] + + ref_motion = comfy.utils.common_upscale(ref_motion.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + if ref_motion.shape[0] < 73: + r = torch.ones([73, height, width, 3]) * 0.5 + r[-ref_motion.shape[0]:] = ref_motion + ref_motion = r + + ref_motion = vae.encode(ref_motion[:, :, :, :3]) + positive = node_helpers.conditioning_set_values(positive, {"reference_motion": ref_motion}) + negative = node_helpers.conditioning_set_values(negative, {"reference_motion": ref_motion}) + + latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + + control_video_out = comfy.latent_formats.Wan21().process_out(torch.zeros_like(latent)) + if control_video is not None: + control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + control_video = vae.encode(control_video[:, :, :, :3]) + control_video_out[:, :, :control_video.shape[2]] = control_video + + # TODO: check if zero is better than none if none provided + positive = node_helpers.conditioning_set_values(positive, {"control_video": control_video_out}) + negative = node_helpers.conditioning_set_values(negative, {"control_video": control_video_out}) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent) + + class Wan22ImageToVideoLatent(io.ComfyNode): @classmethod def define_schema(cls): @@ -844,6 +1018,7 @@ class WanExtension(ComfyExtension): TrimVideoLatent, WanCameraImageToVideo, WanPhantomSubjectToVideo, + WanSoundImageToVideo, Wan22ImageToVideoLatent, ]