diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py
index f80c83ba6..694a183f6 100644
--- a/comfy_extras/nodes_wan.py
+++ b/comfy_extras/nodes_wan.py
@@ -9,29 +9,35 @@ import comfy.clip_vision
 import json
 import numpy as np
 from typing import Tuple
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io
 
-class WanImageToVideo:
+class WanImageToVideo(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "vae": ("VAE", ),
-                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                },
-                "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
-                             "start_image": ("IMAGE", ),
-                }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="WanImageToVideo",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.ClipVisionOutput.Input("clip_vision_output", optional=True),
+                io.Image.Input("start_image", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
 
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "negative", "latent")
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/video_models"
-
-    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None):
+    @classmethod
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None) -> io.NodeOutput:
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         if start_image is not None:
             start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
@@ -51,32 +57,36 @@ class WanImageToVideo:
 
         out_latent = {}
         out_latent["samples"] = latent
-        return (positive, negative, out_latent)
+        return io.NodeOutput(positive, negative, out_latent)
 
-class WanFunControlToVideo:
+class WanFunControlToVideo(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "vae": ("VAE", ),
-                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                },
-                "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
-                             "start_image": ("IMAGE", ),
-                             "control_video": ("IMAGE", ),
-                }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="WanFunControlToVideo",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.ClipVisionOutput.Input("clip_vision_output", optional=True),
+                io.Image.Input("start_image", optional=True),
+                io.Image.Input("control_video", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
 
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "negative", "latent")
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/video_models"
-
-    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None):
+    @classmethod
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None) -> io.NodeOutput:
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
@@ -101,31 +111,34 @@ class WanFunControlToVideo:
 
         out_latent = {}
         out_latent["samples"] = latent
-        return (positive, negative, out_latent)
+        return io.NodeOutput(positive, negative, out_latent)
 
-class Wan22FunControlToVideo:
+class Wan22FunControlToVideo(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "vae": ("VAE", ),
-                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                },
-                "optional": {"ref_image": ("IMAGE", ),
-                             "control_video": ("IMAGE", ),
-                             # "start_image": ("IMAGE", ),
-                }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="Wan22FunControlToVideo",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.Image.Input("ref_image", optional=True),
+                io.Image.Input("control_video", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
 
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "negative", "latent")
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/video_models"
-
-    def encode(self, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None):
+    @classmethod
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None) -> io.NodeOutput:
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
@@ -158,32 +171,36 @@ class Wan22FunControlToVideo:
 
         out_latent = {}
         out_latent["samples"] = latent
-        return (positive, negative, out_latent)
+        return io.NodeOutput(positive, negative, out_latent)
 
-class WanFirstLastFrameToVideo:
+class WanFirstLastFrameToVideo(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "vae": ("VAE", ),
-                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                },
-                "optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ),
-                             "clip_vision_end_image": ("CLIP_VISION_OUTPUT", ),
-                             "start_image": ("IMAGE", ),
-                             "end_image": ("IMAGE", ),
-                }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="WanFirstLastFrameToVideo",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.ClipVisionOutput.Input("clip_vision_start_image", optional=True),
+                io.ClipVisionOutput.Input("clip_vision_end_image", optional=True),
+                io.Image.Input("start_image", optional=True),
+                io.Image.Input("end_image", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
 
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "negative", "latent")
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/video_models"
-
-    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None):
+    @classmethod
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None) -> io.NodeOutput:
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         if start_image is not None:
             start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
@@ -224,62 +241,70 @@ class WanFirstLastFrameToVideo:
 
         out_latent = {}
         out_latent["samples"] = latent
-        return (positive, negative, out_latent)
+        return io.NodeOutput(positive, negative, out_latent)
 
-class WanFunInpaintToVideo:
+class WanFunInpaintToVideo(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "vae": ("VAE", ),
-                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                },
-                "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
-                             "start_image": ("IMAGE", ),
-                             "end_image": ("IMAGE", ),
-                }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="WanFunInpaintToVideo",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.ClipVisionOutput.Input("clip_vision_output", optional=True),
+                io.Image.Input("start_image", optional=True),
+                io.Image.Input("end_image", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
 
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "negative", "latent")
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/video_models"
-
-    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
+    @classmethod
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None) -> io.NodeOutput:
         flfv = WanFirstLastFrameToVideo()
-        return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
+        return flfv.execute(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
 
-class WanVaceToVideo:
+class WanVaceToVideo(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "vae": ("VAE", ),
-                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
-                },
-                "optional": {"control_video": ("IMAGE", ),
-                             "control_masks": ("MASK", ),
-                             "reference_image": ("IMAGE", ),
-                }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="WanVaceToVideo",
+            category="conditioning/video_models",
+            is_experimental=True,
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.Float.Input("strength", default=1.0, min=0.0, max=1000.0, step=0.01),
+                io.Image.Input("control_video", optional=True),
+                io.Mask.Input("control_masks", optional=True),
+                io.Image.Input("reference_image", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+                io.Int.Output(display_name="trim_latent"),
+            ],
+        )
 
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT")
-    RETURN_NAMES = ("positive", "negative", "latent", "trim_latent")
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/video_models"
-
-    EXPERIMENTAL = True
-
-    def encode(self, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None):
+    @classmethod
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None) -> io.NodeOutput:
         latent_length = ((length - 1) // 4) + 1
         if control_video is not None:
             control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
@@ -336,52 +361,59 @@ class WanVaceToVideo:
         latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         out_latent = {}
         out_latent["samples"] = latent
-        return (positive, negative, out_latent, trim_latent)
+        return io.NodeOutput(positive, negative, out_latent, trim_latent)
 
-class TrimVideoLatent:
+class TrimVideoLatent(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "samples": ("LATENT",),
-                              "trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}),
-                              }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="TrimVideoLatent",
+            category="latent/video",
+            is_experimental=True,
+            inputs=[
+                io.Latent.Input("samples"),
+                io.Int.Input("trim_amount", default=0, min=0, max=99999),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
 
-    RETURN_TYPES = ("LATENT",)
-    FUNCTION = "op"
-
-    CATEGORY = "latent/video"
-
-    EXPERIMENTAL = True
-
-    def op(self, samples, trim_amount):
+    @classmethod
+    def execute(cls, samples, trim_amount) -> io.NodeOutput:
         samples_out = samples.copy()
 
         s1 = samples["samples"]
         samples_out["samples"] = s1[:, :, trim_amount:]
-        return (samples_out,)
+        return io.NodeOutput(samples_out)
 
-class WanCameraImageToVideo:
+class WanCameraImageToVideo(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "vae": ("VAE", ),
-                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                },
-                "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
-                             "start_image": ("IMAGE", ),
-                             "camera_conditions": ("WAN_CAMERA_EMBEDDING", ),
-                }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="WanCameraImageToVideo",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.ClipVisionOutput.Input("clip_vision_output", optional=True),
+                io.Image.Input("start_image", optional=True),
+                io.WanCameraEmbedding.Input("camera_conditions", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
 
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "negative", "latent")
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/video_models"
-
-    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None):
+    @classmethod
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None) -> io.NodeOutput:
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
@@ -404,29 +436,34 @@ class WanCameraImageToVideo:
 
         out_latent = {}
         out_latent["samples"] = latent
-        return (positive, negative, out_latent)
+        return io.NodeOutput(positive, negative, out_latent)
 
-class WanPhantomSubjectToVideo:
+class WanPhantomSubjectToVideo(io.ComfyNode):
    @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "vae": ("VAE", ),
-                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                },
-                "optional": {"images": ("IMAGE", ),
-                }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="WanPhantomSubjectToVideo",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.Image.Input("images", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative_text"),
+                io.Conditioning.Output(display_name="negative_img_text"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
 
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent")
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/video_models"
-
-    def encode(self, positive, negative, vae, width, height, length, batch_size, images):
+    @classmethod
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, images) -> io.NodeOutput:
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         cond2 = negative
         if images is not None:
@@ -442,7 +479,7 @@ class WanPhantomSubjectToVideo:
 
         out_latent = {}
         out_latent["samples"] = latent
-        return (positive, cond2, negative, out_latent)
+        return io.NodeOutput(positive, cond2, negative, out_latent)
 
 def parse_json_tracks(tracks):
     """Parse JSON track data into a standardized format"""
@@ -655,39 +692,41 @@ def patch_motion(
 
     return out_mask_full, out_feature_full
 
-class WanTrackToVideo:
+class WanTrackToVideo(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {
-            "positive": ("CONDITIONING", ),
-            "negative": ("CONDITIONING", ),
-            "vae": ("VAE", ),
-            "tracks": ("STRING", {"multiline": True, "default": "[]"}),
-            "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-            "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-            "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-            "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-            "temperature": ("FLOAT", {"default": 220.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
-            "topk": ("INT", {"default": 2, "min": 1, "max": 10}),
-            "start_image": ("IMAGE", ),
-        },
-        "optional": {
-            "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
-        }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="WanTrackToVideo",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.String.Input("tracks", multiline=True, default="[]"),
+                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.Float.Input("temperature", default=220.0, min=1.0, max=1000.0, step=0.1),
+                io.Int.Input("topk", default=2, min=1, max=10),
+                io.Image.Input("start_image"),
+                io.ClipVisionOutput.Input("clip_vision_output", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
 
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "negative", "latent")
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/video_models"
-
-    def encode(self, positive, negative, vae, tracks, width, height, length, batch_size,
-               temperature, topk, start_image=None, clip_vision_output=None):
+    @classmethod
+    def execute(cls, positive, negative, vae, tracks, width, height, length, batch_size,
+                temperature, topk, start_image=None, clip_vision_output=None) -> io.NodeOutput:
         tracks_data = parse_json_tracks(tracks)
 
         if not tracks_data:
-            return WanImageToVideo().encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output)
+            return WanImageToVideo().execute(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output)
 
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
                              device=comfy.model_management.intermediate_device())
@@ -741,34 +780,36 @@ class WanTrackToVideo:
 
         out_latent = {}
         out_latent["samples"] = latent
-        return (positive, negative, out_latent)
+        return io.NodeOutput(positive, negative, out_latent)
 
-class Wan22ImageToVideoLatent:
+class Wan22ImageToVideoLatent(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"vae": ("VAE", ),
-                             "width": ("INT", {"default": 1280, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}),
-                             "height": ("INT", {"default": 704, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}),
-                             "length": ("INT", {"default": 49, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                },
-                "optional": {"start_image": ("IMAGE", ),
-                }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="Wan22ImageToVideoLatent",
+            category="conditioning/inpaint",
+            inputs=[
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=1280, min=32, max=nodes.MAX_RESOLUTION, step=32),
+                io.Int.Input("height", default=704, min=32, max=nodes.MAX_RESOLUTION, step=32),
+                io.Int.Input("length", default=49, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.Image.Input("start_image", optional=True),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
 
-
-    RETURN_TYPES = ("LATENT",)
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/inpaint"
-
-    def encode(self, vae, width, height, length, batch_size, start_image=None):
+    @classmethod
+    def execute(cls, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
         latent = torch.zeros([1, 48, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
 
         if start_image is None:
             out_latent = {}
             out_latent["samples"] = latent
-            return (out_latent,)
+            return io.NodeOutput(out_latent)
 
         mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
@@ -783,19 +824,25 @@ class Wan22ImageToVideoLatent:
         latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
         out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
         out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
-        return (out_latent,)
+        return io.NodeOutput(out_latent)
 
-NODE_CLASS_MAPPINGS = {
-    "WanTrackToVideo": WanTrackToVideo,
-    "WanImageToVideo": WanImageToVideo,
-    "WanFunControlToVideo": WanFunControlToVideo,
-    "Wan22FunControlToVideo": Wan22FunControlToVideo,
-    "WanFunInpaintToVideo": WanFunInpaintToVideo,
-    "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
-    "WanVaceToVideo": WanVaceToVideo,
-    "TrimVideoLatent": TrimVideoLatent,
-    "WanCameraImageToVideo": WanCameraImageToVideo,
-    "WanPhantomSubjectToVideo": WanPhantomSubjectToVideo,
-    "Wan22ImageToVideoLatent": Wan22ImageToVideoLatent,
-}
+class WanExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            WanTrackToVideo,
+            WanImageToVideo,
+            WanFunControlToVideo,
+            Wan22FunControlToVideo,
+            WanFunInpaintToVideo,
+            WanFirstLastFrameToVideo,
+            WanVaceToVideo,
+            TrimVideoLatent,
+            WanCameraImageToVideo,
+            WanPhantomSubjectToVideo,
+            Wan22ImageToVideoLatent,
+        ]
+
+async def comfy_entrypoint() -> WanExtension:
+    return WanExtension()