From edde0b50431e296f61f79205e25cb01f653013a2 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Thu, 28 Aug 2025 14:59:48 -0700
Subject: [PATCH] WanSoundImageToVideoExtend node to manually extend s2v
 video. (#9606)

---
 comfy_extras/nodes_wan.py | 145 +++++++++++++++++++++++++-------------
 1 file changed, 97 insertions(+), 48 deletions(-)

diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py
index 312260f00..0a55bd5d0 100644
--- a/comfy_extras/nodes_wan.py
+++ b/comfy_extras/nodes_wan.py
@@ -877,6 +877,67 @@ def get_audio_embed_bucket_fps(audio_embed, fps=16, batch_frames=81, m=0, video_
     return batch_audio_eb, min_batch_num
 
 
+def wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, frame_offset=0, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None, ref_motion_latent=None):
+    latent_t = ((length - 1) // 4) + 1
+    if audio_encoder_output is not None:
+        feat = torch.cat(audio_encoder_output["encoded_audio_all_layers"])
+        video_rate = 30
+        fps = 16
+        feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate)
+        batch_frames = latent_t * 4
+        audio_embed_bucket, num_repeat = get_audio_embed_bucket_fps(feat, fps=fps, batch_frames=batch_frames, m=0, video_rate=video_rate)
+        audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
+        if len(audio_embed_bucket.shape) == 3:
+            audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
+        elif len(audio_embed_bucket.shape) == 4:
+            audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
+
+        audio_embed_bucket = audio_embed_bucket[:, :, :, frame_offset:frame_offset + batch_frames]
+        positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket})
+        negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket * 0.0})
+        frame_offset += batch_frames
+
+    if ref_image is not None:
+        ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+        ref_latent = vae.encode(ref_image[:, :, :, :3])
+        positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
+        negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
+
+    if ref_motion is not None:
+        if ref_motion.shape[0] > 73:
+            ref_motion = ref_motion[-73:]
+
+        ref_motion = comfy.utils.common_upscale(ref_motion.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+
+        if ref_motion.shape[0] < 73:
+            r = torch.ones([73, height, width, 3]) * 0.5
+            r[-ref_motion.shape[0]:] = ref_motion
+            ref_motion = r
+
+        ref_motion_latent = vae.encode(ref_motion[:, :, :, :3])
+
+    if ref_motion_latent is not None:
+        ref_motion_latent = ref_motion_latent[:, :, -19:]
+        positive = node_helpers.conditioning_set_values(positive, {"reference_motion": ref_motion_latent})
+        negative = node_helpers.conditioning_set_values(negative, {"reference_motion": ref_motion_latent})
+
+    latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+
+    control_video_out = comfy.latent_formats.Wan21().process_out(torch.zeros_like(latent))
+    if control_video is not None:
+        control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+        control_video = vae.encode(control_video[:, :, :, :3])
+        control_video_out[:, :, :control_video.shape[2]] = control_video
+
+    # TODO: check if zero is better than none if none provided
+    positive = node_helpers.conditioning_set_values(positive, {"control_video": control_video_out})
+    negative = node_helpers.conditioning_set_values(negative, {"control_video": control_video_out})
+
+    out_latent = {}
+    out_latent["samples"] = latent
+    return positive, negative, out_latent, frame_offset
+
+
 class WanSoundImageToVideo(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -906,57 +967,44 @@ class WanSoundImageToVideo(io.ComfyNode):
 
     @classmethod
     def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None) -> io.NodeOutput:
-        latent_t = ((length - 1) // 4) + 1
-        if audio_encoder_output is not None:
-            feat = torch.cat(audio_encoder_output["encoded_audio_all_layers"])
-            video_rate = 30
-            fps = 16
-            feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate)
-            audio_embed_bucket, num_repeat = get_audio_embed_bucket_fps(feat, fps=fps, batch_frames=latent_t * 4, m=0, video_rate=video_rate)
-            audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
-            if len(audio_embed_bucket.shape) == 3:
-                audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
-            elif len(audio_embed_bucket.shape) == 4:
-                audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
+        positive, negative, out_latent, frame_offset = wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, ref_image=ref_image, audio_encoder_output=audio_encoder_output,
+                                                                          control_video=control_video, ref_motion=ref_motion)
+        return io.NodeOutput(positive, negative, out_latent)
 
-            positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket})
-            negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket * 0.0})
 
-        if ref_image is not None:
-            ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
-            ref_latent = vae.encode(ref_image[:, :, :, :3])
-            positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
-            negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
+class WanSoundImageToVideoExtend(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="WanSoundImageToVideoExtend",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Latent.Input("video_latent"),
+                io.AudioEncoderOutput.Input("audio_encoder_output", optional=True),
+                io.Image.Input("ref_image", optional=True),
+                io.Image.Input("control_video", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+            is_experimental=True,
+        )
 
-        if ref_motion is not None:
-            if ref_motion.shape[0] > 73:
-                ref_motion = ref_motion[-73:]
-
-            ref_motion = comfy.utils.common_upscale(ref_motion.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
-
-            if ref_motion.shape[0] < 73:
-                r = torch.ones([73, height, width, 3]) * 0.5
-                r[-ref_motion.shape[0]:] = ref_motion
-                ref_motion = r
-
-            ref_motion = vae.encode(ref_motion[:, :, :, :3])
-            positive = node_helpers.conditioning_set_values(positive, {"reference_motion": ref_motion})
-            negative = node_helpers.conditioning_set_values(negative, {"reference_motion": ref_motion})
-
-        latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device())
-
-        control_video_out = comfy.latent_formats.Wan21().process_out(torch.zeros_like(latent))
-        if control_video is not None:
-            control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
-            control_video = vae.encode(control_video[:, :, :, :3])
-            control_video_out[:, :, :control_video.shape[2]] = control_video
-
-        # TODO: check if zero is better than none if none provided
-        positive = node_helpers.conditioning_set_values(positive, {"control_video": control_video_out})
-        negative = node_helpers.conditioning_set_values(negative, {"control_video": control_video_out})
-
-        out_latent = {}
-        out_latent["samples"] = latent
+    @classmethod
+    def execute(cls, positive, negative, vae, length, video_latent, ref_image=None, audio_encoder_output=None, control_video=None) -> io.NodeOutput:
+        video_latent = video_latent["samples"]
+        width = video_latent.shape[-1] * 8
+        height = video_latent.shape[-2] * 8
+        batch_size = video_latent.shape[0]
+        frame_offset = video_latent.shape[-3] * 4
+        positive, negative, out_latent, frame_offset = wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, frame_offset=frame_offset, ref_image=ref_image, audio_encoder_output=audio_encoder_output,
+                                                                          control_video=control_video, ref_motion=None, ref_motion_latent=video_latent)
         return io.NodeOutput(positive, negative, out_latent)
 
 
@@ -1019,6 +1067,7 @@ class WanExtension(ComfyExtension):
             WanCameraImageToVideo,
             WanPhantomSubjectToVideo,
             WanSoundImageToVideo,
+            WanSoundImageToVideoExtend,
             Wan22ImageToVideoLatent,
         ]
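
Usage note (editor's sketch, not part of the patch): WanSoundImageToVideoExtend
derives width, height, batch size, and the audio frame offset from the incoming
video_latent, then passes that latent as ref_motion_latent so its last 19
latent frames become the "reference_motion" conditioning for the next segment.
A minimal sketch of that bookkeeping, assuming the 8x spatial / 4x temporal VAE
compression used in the code above (the 480x832 resolution is only illustrative):

    import torch

    length = 77                          # frames requested for the new segment (node default)
    latent_t = ((length - 1) // 4) + 1   # 20 latent frames, as in wan_sound_to_video()
    batch_frames = latent_t * 4          # 80 audio-embedding frames consumed per segment

    # Stand-in for the sampled latent of a previous WanSoundImageToVideo run:
    # [batch, channels, latent_t, height // 8, width // 8] for a 480x832 video.
    prev = torch.zeros([1, 16, 20, 60, 104])

    width = prev.shape[-1] * 8           # 832
    height = prev.shape[-2] * 8          # 480
    frame_offset = prev.shape[-3] * 4    # 80: the audio bucket is then sliced at
                                         # [frame_offset, frame_offset + batch_frames)

    # The tail of the previous latent carries the motion into the new segment.
    reference_motion = prev[:, :, -19:]
    print(width, height, frame_offset, tuple(reference_motion.shape))

Because frame_offset scales with the temporal length of the latent that is fed
in, passing the full video-so-far latent makes each extension pick up the next
unused slice of the encoded audio.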