From 78672d0ee6d20d8269f324474643e5cc00f1c348 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 27 Jul 2025 04:42:58 -0700 Subject: [PATCH 1/6] Small readme update. (#9071) --- README.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index a148623cd..8a15136aa 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith ## Features - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. - Image Models - - SD1.x, SD2.x, + - SD1.x, SD2.x ([unCLIP](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)) - [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/) - [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/) - [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) @@ -84,9 +84,9 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2) - Asynchronous Queue system - Many optimizations: Only re-executes the parts of the workflow that changes between executions. -- Smart memory management: can automatically run models on GPUs with as low as 1GB vram. +- Smart memory management: can automatically run large models on GPUs with as low as 1GB vram with smart offloading. - Works even if you don't have a GPU with: ```--cpu``` (slow) -- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models. +- Can load ckpt and safetensors: All in one checkpoints or standalone diffusion models, VAEs and CLIP models. - Safe loading of ckpt, pt, pth, etc.. files. - Embeddings/Textual inversion - [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/) @@ -98,7 +98,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models. - [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/) - [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/) -- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/) - [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/) - [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/) - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/) From e6d9f6274494c5ac96295deb1bea54de50189059 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Sun, 27 Jul 2025 16:51:36 -0700 Subject: [PATCH 2/6] Add Moonvalley Marey V2V node with updated input validation (#9069) * [moonvalley] Update V2V node to match API specification - Add exact resolution validation for supported resolutions (1920x1080, 1080x1920, 1152x1152, 1536x1152, 1152x1536) - Change frame count validation from divisible by 32 to 16 - Add MP4 container format validation - Remove internal parameters (steps, guidance_scale) from V2V inference params - Update video duration handling to support only 5 seconds (auto-trim if longer) - Add motion_intensity parameter (0-100) for Motion Transfer control type - Add get_container_format() method to VideoInput classes * update negative prompt --- comfy_api/input/video_types.py | 13 ++ comfy_api/input_impl/video_types.py | 12 ++ comfy_api_nodes/nodes_moonvalley.py | 225 +++++++++++++++------------- 3 files changed, 145 insertions(+), 105 deletions(-) diff --git a/comfy_api/input/video_types.py b/comfy_api/input/video_types.py index bb936e0a4..5d95dc507 100644 --- a/comfy_api/input/video_types.py +++ b/comfy_api/input/video_types.py @@ -2,6 +2,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from typing import Optional, Union import io +import av from comfy_api.util import VideoContainer, VideoCodec, VideoComponents class VideoInput(ABC): @@ -70,3 +71,15 @@ class VideoInput(ABC): components = self.get_components() frame_count = components.images.shape[0] return float(frame_count / components.frame_rate) + + def get_container_format(self) -> str: + """ + Returns the container format of the video (e.g., 'mp4', 'mov', 'avi'). + + Returns: + Container format as string + """ + # Default implementation - subclasses should override for better performance + source = self.get_stream_source() + with av.open(source, mode="r") as container: + return container.format.name diff --git a/comfy_api/input_impl/video_types.py b/comfy_api/input_impl/video_types.py index 9ae818f4e..91e7c1bfa 100644 --- a/comfy_api/input_impl/video_types.py +++ b/comfy_api/input_impl/video_types.py @@ -121,6 +121,18 @@ class VideoFromFile(VideoInput): raise ValueError(f"Could not determine duration for file '{self.__file}'") + def get_container_format(self) -> str: + """ + Returns the container format of the video (e.g., 'mp4', 'mov', 'avi'). + + Returns: + Container format as string + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) + with av.open(self.__file, mode='r') as container: + return container.format.name + def get_components_internal(self, container: InputContainer) -> VideoComponents: # Get video frames frames = [] diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 057021efa..789fcef02 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -5,7 +5,6 @@ import torch from comfy_api_nodes.util.validation_utils import ( get_image_dimensions, validate_image_dimensions, - validate_video_dimensions, ) @@ -176,54 +175,76 @@ def validate_input_image( ) -def validate_input_video( - video: VideoInput, num_frames_out: int, with_frame_conditioning: bool = False -): +def validate_video_to_video_input(video: VideoInput) -> VideoInput: + """ + Validates and processes video input for Moonvalley Video-to-Video generation. + + Args: + video: Input video to validate + + Returns: + Validated and potentially trimmed video + + Raises: + ValueError: If video doesn't meet requirements + MoonvalleyApiError: If video duration is too short + """ + width, height = _get_video_dimensions(video) + _validate_video_dimensions(width, height) + _validate_container_format(video) + + return _validate_and_trim_duration(video) + + +def _get_video_dimensions(video: VideoInput) -> tuple[int, int]: + """Extracts video dimensions with error handling.""" try: - width, height = video.get_dimensions() + return video.get_dimensions() except Exception as e: logging.error("Error getting dimensions of video: %s", e) raise ValueError(f"Cannot get video dimensions: {e}") from e - validate_input_media(width, height, with_frame_conditioning) - validate_video_dimensions( - video, - min_width=MIN_VID_WIDTH, - min_height=MIN_VID_HEIGHT, - max_width=MAX_VID_WIDTH, - max_height=MAX_VID_HEIGHT, - ) - trimmed_video = validate_input_video_length(video, num_frames_out) - return trimmed_video +def _validate_video_dimensions(width: int, height: int) -> None: + """Validates video dimensions meet Moonvalley V2V requirements.""" + supported_resolutions = { + (1920, 1080), (1080, 1920), (1152, 1152), + (1536, 1152), (1152, 1536) + } + + if (width, height) not in supported_resolutions: + supported_list = ', '.join([f'{w}x{h}' for w, h in sorted(supported_resolutions)]) + raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}") -def validate_input_video_length(video: VideoInput, num_frames: int): +def _validate_container_format(video: VideoInput) -> None: + """Validates video container format is MP4.""" + container_format = video.get_container_format() + if container_format not in ['mp4', 'mov,mp4,m4a,3gp,3g2,mj2']: + raise ValueError(f"Only MP4 container format supported. Got: {container_format}") - if video.get_duration() > 60: - raise MoonvalleyApiError( - "Input Video lenth should be less than 1min. Please trim." - ) - if num_frames == 128: - if video.get_duration() < 5: - raise MoonvalleyApiError( - "Input Video length is less than 5s. Please use a video longer than or equal to 5s." - ) - if video.get_duration() > 5: - # trim video to 5s - video = trim_video(video, 5) - if num_frames == 256: - if video.get_duration() < 10: - raise MoonvalleyApiError( - "Input Video length is less than 10s. Please use a video longer than or equal to 10s." - ) - if video.get_duration() > 10: - # trim video to 10s - video = trim_video(video, 10) +def _validate_and_trim_duration(video: VideoInput) -> VideoInput: + """Validates video duration and trims to 5 seconds if needed.""" + duration = video.get_duration() + _validate_minimum_duration(duration) + return _trim_if_too_long(video, duration) + + +def _validate_minimum_duration(duration: float) -> None: + """Ensures video is at least 5 seconds long.""" + if duration < 5: + raise MoonvalleyApiError("Input video must be at least 5 seconds long.") + + +def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput: + """Trims video to 5 seconds if longer.""" + if duration > 5: + return trim_video(video, 5) return video + def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: """ Returns a new VideoInput object trimmed from the beginning to the specified duration, @@ -278,15 +299,13 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels" ) - # Calculate target frame count that's divisible by 32 + # Calculate target frame count that's divisible by 16 fps = input_container.streams.video[0].average_rate estimated_frames = int(duration_sec * fps) - target_frames = ( - estimated_frames // 32 - ) * 32 # Round down to nearest multiple of 32 + target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16 if target_frames == 0: - raise ValueError("Video too short: need at least 32 frames for Moonvalley") + raise ValueError("Video too short: need at least 16 frames for Moonvalley") frame_count = 0 audio_frame_count = 0 @@ -353,8 +372,8 @@ class BaseMoonvalleyVideoNode: "16:9 (1920 x 1080)": {"width": 1920, "height": 1080}, "9:16 (1080 x 1920)": {"width": 1080, "height": 1920}, "1:1 (1152 x 1152)": {"width": 1152, "height": 1152}, - "4:3 (1440 x 1080)": {"width": 1440, "height": 1080}, - "3:4 (1080 x 1440)": {"width": 1080, "height": 1440}, + "4:3 (1536 x 1152)": {"width": 1536, "height": 1152}, + "3:4 (1152 x 1536)": {"width": 1152, "height": 1536}, "21:9 (2560 x 1080)": {"width": 2560, "height": 1080}, } if resolution in res_map: @@ -494,7 +513,6 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode): image = kwargs.get("image", None) if image is None: raise MoonvalleyApiError("image is required") - total_frames = get_total_frames_from_length() validate_input_image(image, True) validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) @@ -505,7 +523,7 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode): steps=kwargs.get("steps"), seed=kwargs.get("seed"), guidance_scale=kwargs.get("prompt_adherence"), - num_frames=total_frames, + num_frames=128, width=width_height.get("width"), height=width_height.get("height"), use_negative_prompts=True, @@ -549,39 +567,45 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): @classmethod def INPUT_TYPES(cls): - input_types = super().INPUT_TYPES() - for param in ["resolution", "image"]: - if param in input_types["required"]: - del input_types["required"][param] - if param in input_types["optional"]: - del input_types["optional"][param] - input_types["optional"] = { - "video": ( - IO.VIDEO, - { - "default": "", - "multiline": False, - "tooltip": "The reference video used to generate the output video. Input a 5s video for 128 frames and a 10s video for 256 frames. Longer videos will be trimmed automatically.", - }, - ), - "control_type": ( - ["Motion Transfer", "Pose Transfer"], - {"default": "Motion Transfer"}, - ), - "motion_intensity": ( - "INT", - { - "default": 100, - "step": 1, - "min": 0, - "max": 100, - "tooltip": "Only used if control_type is 'Motion Transfer'", - }, - ), + return { + "required": { + "prompt": model_field_to_node_input( + IO.STRING, MoonvalleyVideoToVideoRequest, "prompt_text", + multiline=True + ), + "negative_prompt": model_field_to_node_input( + IO.STRING, + MoonvalleyVideoToVideoInferenceParams, + "negative_prompt", + multiline=True, + default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts" + ), + "seed": model_field_to_node_input(IO.INT,MoonvalleyVideoToVideoInferenceParams, "seed", default=random.randint(0, 2**32 - 1), min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", control_after_generate=True), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", + }, + "optional": { + "video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported."}), + "control_type": ( + ["Motion Transfer", "Pose Transfer"], + {"default": "Motion Transfer"}, + ), + "motion_intensity": ( + "INT", + { + "default": 100, + "step": 1, + "min": 0, + "max": 100, + "tooltip": "Only used if control_type is 'Motion Transfer'", + }, + ) + } } - return input_types - RETURN_TYPES = ("VIDEO",) RETURN_NAMES = ("video",) @@ -589,15 +613,13 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs ): video = kwargs.get("video") - num_frames = get_total_frames_from_length() if not video: raise MoonvalleyApiError("video is required") - """Validate video input""" video_url = "" if video: - validated_video = validate_input_video(video, num_frames, False) + validated_video = validate_video_to_video_input(video) video_url = upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs) control_type = kwargs.get("control_type") @@ -605,12 +627,16 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): """Validate prompts and inference input""" validate_prompts(prompt, negative_prompt) - inference_params = MoonvalleyVideoToVideoInferenceParams( + + # Only include motion_intensity for Motion Transfer + control_params = {} + if control_type == "Motion Transfer" and motion_intensity is not None: + control_params['motion_intensity'] = motion_intensity + + inference_params=MoonvalleyVideoToVideoInferenceParams( negative_prompt=negative_prompt, - steps=kwargs.get("steps"), seed=kwargs.get("seed"), - guidance_scale=kwargs.get("prompt_adherence"), - control_params={"motion_intensity": motion_intensity}, + control_params=control_params ) control = self.parseControlParameter(control_type) @@ -667,17 +693,16 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode): ): validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) width_height = self.parseWidthHeightFromRes(kwargs.get("resolution")) - num_frames = get_total_frames_from_length() - inference_params = MoonvalleyTextToVideoInferenceParams( - negative_prompt=negative_prompt, - steps=kwargs.get("steps"), - seed=kwargs.get("seed"), - guidance_scale=kwargs.get("prompt_adherence"), - num_frames=num_frames, - width=width_height.get("width"), - height=width_height.get("height"), - ) + inference_params=MoonvalleyTextToVideoInferenceParams( + negative_prompt=negative_prompt, + steps=kwargs.get("steps"), + seed=kwargs.get("seed"), + guidance_scale=kwargs.get("prompt_adherence"), + num_frames=128, + width=width_height.get("width"), + height=width_height.get("height"), + ) request = MoonvalleyTextToVideoRequest( prompt_text=prompt, inference_params=inference_params ) @@ -707,22 +732,12 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode): NODE_CLASS_MAPPINGS = { "MoonvalleyImg2VideoNode": MoonvalleyImg2VideoNode, "MoonvalleyTxt2VideoNode": MoonvalleyTxt2VideoNode, - # "MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode, + "MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode, } NODE_DISPLAY_NAME_MAPPINGS = { "MoonvalleyImg2VideoNode": "Moonvalley Marey Image to Video", "MoonvalleyTxt2VideoNode": "Moonvalley Marey Text to Video", - # "MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video", + "MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video", } - - -def get_total_frames_from_length(length="5s"): - # if length == '5s': - # return 128 - # elif length == '10s': - # return 256 - return 128 - # else: - # raise MoonvalleyApiError("length is required") From d0210fe2e5df25b329926e20e3be32451fd5b841 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Mon, 28 Jul 2025 19:55:02 +0800 Subject: [PATCH 3/6] Update template to 0.1.41 (#9079) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 33a59b4be..14a085a2c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.23.4 -comfyui-workflow-templates==0.1.40 +comfyui-workflow-templates==0.1.41 comfyui-embedded-docs==0.2.4 torch torchsde From a88788dce6b0d7b5e2876c7cd0121b45e80f4ad8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 28 Jul 2025 05:00:23 -0700 Subject: [PATCH 4/6] Wan 2.2 support. (#9080) --- comfy/latent_formats.py | 76 ++++ comfy/ldm/wan/model.py | 16 +- comfy/ldm/wan/vae2_2.py | 726 ++++++++++++++++++++++++++++++++++++++ comfy/model_base.py | 30 +- comfy/model_detection.py | 2 + comfy/sd.py | 36 +- comfy/supported_models.py | 15 +- comfy_extras/nodes_wan.py | 44 +++ 8 files changed, 926 insertions(+), 19 deletions(-) create mode 100644 comfy/ldm/wan/vae2_2.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 82d9f9bb8..caf4991fc 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -457,6 +457,82 @@ class Wan21(LatentFormat): latents_std = self.latents_std.to(latent.device, latent.dtype) return latent * latents_std / self.scale_factor + latents_mean +class Wan22(Wan21): + latent_channels = 48 + latent_dimensions = 3 + + latent_rgb_factors = [ + [ 0.0119, 0.0103, 0.0046], + [-0.1062, -0.0504, 0.0165], + [ 0.0140, 0.0409, 0.0491], + [-0.0813, -0.0677, 0.0607], + [ 0.0656, 0.0851, 0.0808], + [ 0.0264, 0.0463, 0.0912], + [ 0.0295, 0.0326, 0.0590], + [-0.0244, -0.0270, 0.0025], + [ 0.0443, -0.0102, 0.0288], + [-0.0465, -0.0090, -0.0205], + [ 0.0359, 0.0236, 0.0082], + [-0.0776, 0.0854, 0.1048], + [ 0.0564, 0.0264, 0.0561], + [ 0.0006, 0.0594, 0.0418], + [-0.0319, -0.0542, -0.0637], + [-0.0268, 0.0024, 0.0260], + [ 0.0539, 0.0265, 0.0358], + [-0.0359, -0.0312, -0.0287], + [-0.0285, -0.1032, -0.1237], + [ 0.1041, 0.0537, 0.0622], + [-0.0086, -0.0374, -0.0051], + [ 0.0390, 0.0670, 0.2863], + [ 0.0069, 0.0144, 0.0082], + [ 0.0006, -0.0167, 0.0079], + [ 0.0313, -0.0574, -0.0232], + [-0.1454, -0.0902, -0.0481], + [ 0.0714, 0.0827, 0.0447], + [-0.0304, -0.0574, -0.0196], + [ 0.0401, 0.0384, 0.0204], + [-0.0758, -0.0297, -0.0014], + [ 0.0568, 0.1307, 0.1372], + [-0.0055, -0.0310, -0.0380], + [ 0.0239, -0.0305, 0.0325], + [-0.0663, -0.0673, -0.0140], + [-0.0416, -0.0047, -0.0023], + [ 0.0166, 0.0112, -0.0093], + [-0.0211, 0.0011, 0.0331], + [ 0.1833, 0.1466, 0.2250], + [-0.0368, 0.0370, 0.0295], + [-0.3441, -0.3543, -0.2008], + [-0.0479, -0.0489, -0.0420], + [-0.0660, -0.0153, 0.0800], + [-0.0101, 0.0068, 0.0156], + [-0.0690, -0.0452, -0.0927], + [-0.0145, 0.0041, 0.0015], + [ 0.0421, 0.0451, 0.0373], + [ 0.0504, -0.0483, -0.0356], + [-0.0837, 0.0168, 0.0055] + ] + + latent_rgb_factors_bias = [0.0317, -0.0878, -0.1388] + + def __init__(self): + self.scale_factor = 1.0 + self.latents_mean = torch.tensor([ + -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, + -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825, + -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502, + -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230, + -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748, + 0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667, + ]).view(1, self.latent_channels, 1, 1, 1) + self.latents_std = torch.tensor([ + 0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013, + 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978, + 0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659, + 0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093, + 0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887, + 0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744 + ]).view(1, self.latent_channels, 1, 1, 1) + class Hunyuan3Dv2(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 1d6edb354..b9e47e9f7 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -201,8 +201,10 @@ class WanAttentionBlock(nn.Module): freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ # assert e.dtype == torch.float32 - - e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1) + if e.ndim < 4: + e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1) + else: + e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2) # assert e[0].dtype == torch.float32 # self-attention @@ -325,7 +327,10 @@ class Head(nn.Module): e(Tensor): Shape [B, C] """ # assert e.dtype == torch.float32 - e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1) + if e.ndim < 3: + e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1) + else: + e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2) x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) return x @@ -506,8 +511,9 @@ class WanModel(torch.nn.Module): # time embeddings e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype)) - e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype)) + e = e.reshape(t.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) # context context = self.text_embedding(context) diff --git a/comfy/ldm/wan/vae2_2.py b/comfy/ldm/wan/vae2_2.py new file mode 100644 index 000000000..c2c150e10 --- /dev/null +++ b/comfy/ldm/wan/vae2_2.py @@ -0,0 +1,726 @@ +# original version: https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/vae2_2.py +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from .vae import AttentionBlock, CausalConv3d, RMS_norm + +import comfy.ops +ops = comfy.ops.disable_weight_init + +CACHE_T = 2 + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ( + "none", + "upsample2d", + "upsample3d", + "downsample2d", + "downsample3d", + ) + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + ops.Conv2d(dim, dim, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + ops.Conv2d(dim, dim, 3, padding=1), + # ops.Conv2d(dim, dim//2, 3, padding=1) + ) + self.time_conv = CausalConv3d( + dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + elif mode == "downsample2d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + ops.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + ops.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and + feat_cache[idx] != "Rep"): + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and + feat_cache[idx] == "Rep"): + cache_x = torch.cat( + [ + torch.zeros_like(cache_x).to(cache_x.device), + cache_x + ], + dim=2, + ) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), + 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.resample(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), + nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), + nn.SiLU(), + nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1), + ) + self.shortcut = ( + CausalConv3d(in_dim, out_dim, 1) + if in_dim != out_dim else nn.Identity()) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +def patchify(x, patch_size): + if patch_size == 1: + return x + if x.dim() == 4: + x = rearrange( + x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b c f (h q) (w r) -> b (c r q) f h w", + q=patch_size, + r=patch_size, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x, patch_size): + if patch_size == 1: + return x + + if x.dim() == 4: + x = rearrange( + x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b (c r q) f h w -> b c f (h q) (w r)", + q=patch_size, + r=patch_size, + ) + return x + + +class AvgDown3D(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert in_channels * self.factor % out_channels == 0 + self.group_size = in_channels * self.factor // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t + pad = (0, 0, 0, 0, pad_t, 0) + x = F.pad(x, pad) + B, C, T, H, W = x.shape + x = x.view( + B, + C, + T // self.factor_t, + self.factor_t, + H // self.factor_s, + self.factor_s, + W // self.factor_s, + self.factor_s, + ) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view( + B, + C * self.factor, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.view( + B, + self.out_channels, + self.group_size, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.mean(dim=2) + return x + + +class DupUp3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert out_channels * self.factor % in_channels == 0 + self.repeats = out_channels * self.factor // in_channels + + def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = x.view( + x.size(0), + self.out_channels, + self.factor_t, + self.factor_s, + self.factor_s, + x.size(2), + x.size(3), + x.size(4), + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view( + x.size(0), + self.out_channels, + x.size(2) * self.factor_t, + x.size(4) * self.factor_s, + x.size(6) * self.factor_s, + ) + if first_chunk: + x = x[:, :, self.factor_t - 1:, :, :] + return x + + +class Down_ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + out_dim, + dropout, + mult, + temperal_downsample=False, + down_flag=False): + super().__init__() + + # Shortcut path with downsample + self.avg_shortcut = AvgDown3D( + in_dim, + out_dim, + factor_t=2 if temperal_downsample else 1, + factor_s=2 if down_flag else 1, + ) + + # Main path with residual blocks and downsample + downsamples = [] + for _ in range(mult): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final downsample block + if down_flag: + mode = "downsample3d" if temperal_downsample else "downsample2d" + downsamples.append(Resample(out_dim, mode=mode)) + + self.downsamples = nn.Sequential(*downsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + x_copy = x.clone() + for module in self.downsamples: + x = module(x, feat_cache, feat_idx) + + return x + self.avg_shortcut(x_copy) + + +class Up_ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + out_dim, + dropout, + mult, + temperal_upsample=False, + up_flag=False): + super().__init__() + # Shortcut path with upsample + if up_flag: + self.avg_shortcut = DupUp3D( + in_dim, + out_dim, + factor_t=2 if temperal_upsample else 1, + factor_s=2 if up_flag else 1, + ) + else: + self.avg_shortcut = None + + # Main path with residual blocks and upsample + upsamples = [] + for _ in range(mult): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final upsample block + if up_flag: + mode = "upsample3d" if temperal_upsample else "upsample2d" + upsamples.append(Resample(out_dim, mode=mode)) + + self.upsamples = nn.Sequential(*upsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + x_main = x.clone() + for module in self.upsamples: + x_main = module(x_main, feat_cache, feat_idx) + if self.avg_shortcut is not None: + x_shortcut = self.avg_shortcut(x, first_chunk) + return x_main + x_shortcut + else: + return x_main + + +class Encoder3d(nn.Module): + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(12, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_down_flag = ( + temperal_downsample[i] + if i < len(temperal_downsample) else False) + downsamples.append( + Down_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks, + temperal_downsample=t_down_flag, + down_flag=i != len(dim_mult) - 1, + )) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout), + ) + + # # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1), + ) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + + return x + + +class Decoder3d(nn.Module): + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout), + ) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_up_flag = temperal_upsample[i] if i < len( + temperal_upsample) else False + upsamples.append( + Up_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks + 1, + temperal_upsample=t_up_flag, + up_flag=i != len(dim_mult) - 1, + )) + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, 12, 3, padding=1), + ) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx, first_chunk) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class WanVAE(nn.Module): + + def __init__( + self, + dim=160, + dec_dim=256, + z_dim=16, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d( + dim, + z_dim * 2, + dim_mult, + num_res_blocks, + attn_scales, + self.temperal_downsample, + dropout, + ) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d( + dec_dim, + z_dim, + dim_mult, + num_res_blocks, + attn_scales, + self.temperal_upsample, + dropout, + ) + + def encode(self, x): + self.clear_cache() + x = patchify(x, patch_size=2) + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder( + x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + self.clear_cache() + return mu + + def decode(self, z): + self.clear_cache() + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + first_chunk=True, + ) + else: + out_ = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + ) + out = torch.cat([out, out_], 2) + out = unpatchify(out, patch_size=2) + self.clear_cache() + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num diff --git a/comfy/model_base.py b/comfy/model_base.py index 4392355ea..d019b991a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1097,8 +1097,9 @@ class WAN21(BaseModel): image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16]) image = utils.resize_to_batch_size(image, noise.shape[0]) - if not self.image_to_video or extra_channels == image.shape[1]: - return image + if extra_channels != image.shape[1] + 4: + if not self.image_to_video or extra_channels == image.shape[1]: + return image if image.shape[1] > (extra_channels - 4): image = image[:, :(extra_channels - 4)] @@ -1182,6 +1183,31 @@ class WAN21_Camera(WAN21): out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions) return out +class WAN22(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if denoise_mask is not None: + out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask) + return out + + def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): + if denoise_mask is None: + return timestep + temp_ts = (torch.mean(denoise_mask[:, :, :, ::2, ::2], dim=1, keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1) + return temp_ts + + def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): + return latent_image + class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 18232ade3..9fc1f42de 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -346,7 +346,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config = {} dit_config["image_model"] = "wan2.1" dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1] + out_dim = state_dict['{}head.head.weight'.format(key_prefix)].shape[0] // 4 dit_config["dim"] = dim + dit_config["out_dim"] = out_dim dit_config["num_heads"] = dim // 128 dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0] dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.') diff --git a/comfy/sd.py b/comfy/sd.py index 8081b167c..e0498e585 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -14,6 +14,7 @@ import comfy.ldm.genmo.vae.model import comfy.ldm.lightricks.vae.causal_video_autoencoder import comfy.ldm.cosmos.vae import comfy.ldm.wan.vae +import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae import comfy.ldm.ace.vae.music_dcae_pipeline import yaml @@ -420,17 +421,30 @@ class VAE: self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype) self.working_dtypes = [torch.bfloat16, torch.float32] elif "decoder.middle.0.residual.0.gamma" in sd: - self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) - self.upscale_index_formula = (4, 8, 8) - self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) - self.downscale_index_formula = (4, 8, 8) - self.latent_dim = 3 - self.latent_channels = 16 - ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} - self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) - self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] - self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype) - self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype) + if "decoder.upsamples.0.upsamples.0.residual.2.weight" in sd: # Wan 2.2 VAE + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) + self.upscale_index_formula = (4, 16, 16) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) + self.downscale_index_formula = (4, 16, 16) + self.latent_dim = 3 + self.latent_channels = 48 + ddconfig = {"dim": 160, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} + self.first_stage_model = comfy.ldm.wan.vae2_2.WanVAE(**ddconfig) + self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] + self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype) + else: # Wan 2.1 VAE + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + self.upscale_index_formula = (4, 8, 8) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) + self.downscale_index_formula = (4, 8, 8) + self.latent_dim = 3 + self.latent_channels = 16 + ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} + self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) + self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] + self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype) elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd: self.latent_dim = 1 ln_post = "geo_decoder.ln_post.weight" in sd diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 2ca3857f7..8f3f4652d 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1059,6 +1059,19 @@ class WAN21_Vace(WAN21_T2V): out = model_base.WAN21_Vace(self, image_to_video=False, device=device) return out +class WAN22_T2V(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "t2v", + "out_dim": 48, + } + + latent_format = latent_formats.Wan22 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN22(self, image_to_video=True, device=device) + return out + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", @@ -1217,6 +1230,6 @@ class Omnigen2(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index d71908f31..0b92c68ac 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -685,6 +685,49 @@ class WanTrackToVideo: out_latent["samples"] = latent return (positive, negative, out_latent) + +class Wan22ImageToVideoLatent: + @classmethod + def INPUT_TYPES(s): + return {"required": {"vae": ("VAE", ), + "width": ("INT", {"default": 1280, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}), + "height": ("INT", {"default": 704, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}), + "length": ("INT", {"default": 49, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"start_image": ("IMAGE", ), + }} + + + RETURN_TYPES = ("LATENT",) + FUNCTION = "encode" + + CATEGORY = "conditioning/inpaint" + + def encode(self, vae, width, height, length, batch_size, start_image=None): + latent = torch.zeros([1, 48, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device()) + + if start_image is None: + out_latent = {} + out_latent["samples"] = latent + return (out_latent,) + + mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) + + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + latent_temp = vae.encode(start_image) + latent[:, :, :latent_temp.shape[-3]] = latent_temp + mask[:, :, :latent_temp.shape[-3]] *= 0.0 + + out_latent = {} + latent_format = comfy.latent_formats.Wan22() + latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask) + out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1)) + out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) + return (out_latent,) + + NODE_CLASS_MAPPINGS = { "WanTrackToVideo": WanTrackToVideo, "WanImageToVideo": WanImageToVideo, @@ -695,4 +738,5 @@ NODE_CLASS_MAPPINGS = { "TrimVideoLatent": TrimVideoLatent, "WanCameraImageToVideo": WanCameraImageToVideo, "WanPhantomSubjectToVideo": WanPhantomSubjectToVideo, + "Wan22ImageToVideoLatent": Wan22ImageToVideoLatent, } From 9f1388c0a38b9b6ebde0cdde904d94d709d3ca82 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 28 Jul 2025 05:01:53 -0700 Subject: [PATCH 5/6] Add wan2.2 to readme. (#9081) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 8a15136aa..befc4c006 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/) - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/) - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/) + - [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/) - Audio Models - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/) - [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/) From 5d4cc3ba1b412b9acacd37fd23d59e0e1654f83c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 28 Jul 2025 08:04:04 -0400 Subject: [PATCH 6/6] ComfyUI 0.3.46 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 180ecaf8a..315710dd2 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.45" +__version__ = "0.3.46" diff --git a/pyproject.toml b/pyproject.toml index b1d6d9df6..59c4c70fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.45" +version = "0.3.46" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9"