From 4293e4da214f77a3fde97c15f0691307e61bc18d Mon Sep 17 00:00:00 2001
From: Eugene Fairley
Date: Thu, 24 Jul 2025 17:59:19 -0700
Subject: [PATCH] Add WAN ATI support (#8874)

* Add WAN ATI support
* Fixes
* Fix length
* Remove extra functions
* Fix
* Fix
* Ruff fix
* Remove torch.no_grad
* Add batch trajectory logic
* Scale inputs before and after motion patch
* Batch image/trajectory
* Ruff fix
* Clean up
---
 comfy/utils.py            |  20 +++
 comfy_extras/nodes_wan.py | 305 +++++++++++++++++++++++++++++++++++++-
 2 files changed, 324 insertions(+), 1 deletion(-)

diff --git a/comfy/utils.py b/comfy/utils.py
index 9c076a0e0..fab28cf08 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -698,6 +698,26 @@ def resize_to_batch_size(tensor, batch_size):
 
     return output
 
+def resize_list_to_batch_size(l, batch_size):
+    in_batch_size = len(l)
+    if in_batch_size == batch_size or in_batch_size == 0:
+        return l
+
+    if batch_size <= 1:
+        return l[:batch_size]
+
+    output = []
+    if batch_size < in_batch_size:
+        scale = (in_batch_size - 1) / (batch_size - 1)
+        for i in range(batch_size):
+            output.append(l[min(round(i * scale), in_batch_size - 1)])
+    else:
+        scale = in_batch_size / batch_size
+        for i in range(batch_size):
+            output.append(l[min(math.floor((i + 0.5) * scale), in_batch_size - 1)])
+
+    return output
+
 def convert_sd_to(state_dict, dtype):
     keys = list(state_dict.keys())
     for k in keys:
diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py
index d6097a104..d71908f31 100644
--- a/comfy_extras/nodes_wan.py
+++ b/comfy_extras/nodes_wan.py
@@ -1,3 +1,4 @@
+import math
 import nodes
 import node_helpers
 import torch
@@ -5,7 +6,9 @@ import comfy.model_management
 import comfy.utils
 import comfy.latent_formats
 import comfy.clip_vision
-
+import json
+import numpy as np
+from typing import Tuple
 
 class WanImageToVideo:
     @classmethod
@@ -383,7 +386,307 @@ class WanPhantomSubjectToVideo:
         out_latent["samples"] = latent
         return (positive, cond2, negative, out_latent)
 
+def parse_json_tracks(tracks):
+    """Parse JSON track data into a standardized format"""
+    tracks_data = []
+    try:
+        # If tracks is a string, try to parse it as JSON
+        if isinstance(tracks, str):
+            parsed = json.loads(tracks.replace("'", '"'))
+            tracks_data.extend(parsed)
+        else:
+            # If tracks is a list of strings, parse each one
+            for track_str in tracks:
+                parsed = json.loads(track_str.replace("'", '"'))
+                tracks_data.append(parsed)
+
+        # Check if we have a single track (dict with x,y) or a list of tracks
+        if tracks_data and isinstance(tracks_data[0], dict) and 'x' in tracks_data[0]:
+            # Single track detected, wrap it in a list
+            tracks_data = [tracks_data]
+        elif tracks_data and isinstance(tracks_data[0], list) and tracks_data[0] and isinstance(tracks_data[0][0], dict) and 'x' in tracks_data[0][0]:
+            # Already a list of tracks, nothing to do
+            pass
+        else:
+            # Unexpected format
+            pass
+
+    except json.JSONDecodeError:
+        tracks_data = []
+    return tracks_data
+
+def process_tracks(tracks_np: np.ndarray, frame_size: Tuple[int, int], num_frames, quant_multi: int = 8, **kwargs):
+    # tracks: shape [t, h, w, 3] => samples align with 24 fps, model trained with 16 fps.
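+    # In this node, tracks_np comes from pad_pts()/np.stack(), i.e. (num_tracks, 121, 1, 3)
+    # holding (x, y, visible) per sample; the permute below makes it time-major, and the tail
+    # of this function resamples the 120 motion samples onto the requested num_frames.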
+    # frame_size: tuple (W, H)
+    tracks = torch.from_numpy(tracks_np).float()
+
+    if tracks.shape[1] == 121:
+        tracks = torch.permute(tracks, (1, 0, 2, 3))
+
+    tracks, visibles = tracks[..., :2], tracks[..., 2:3]
+
+    short_edge = min(*frame_size)
+
+    frame_center = torch.tensor([*frame_size]).type_as(tracks) / 2
+    tracks = tracks - frame_center
+
+    tracks = tracks / short_edge * 2
+
+    visibles = visibles * 2 - 1
+
+    trange = torch.linspace(-1, 1, tracks.shape[0]).view(-1, 1, 1, 1).expand(*visibles.shape)
+
+    out_ = torch.cat([trange, tracks, visibles], dim=-1).view(121, -1, 4)
+
+    out_0 = out_[:1]
+
+    out_l = out_[1:]  # 121 => 120 | 1
+    a = 120 // math.gcd(120, num_frames)
+    b = num_frames // math.gcd(120, num_frames)
+    out_l = torch.repeat_interleave(out_l, b, dim=0)[1::a]  # 120 => 120 * b => 120 * b / a == F
+
+    final_result = torch.cat([out_0, out_l], dim=0)
+
+    return final_result
+
+FIXED_LENGTH = 121
+def pad_pts(tr):
+    """Convert list of {x,y} to (FIXED_LENGTH,1,3) array, padding/truncating."""
+    pts = np.array([[p['x'], p['y'], 1] for p in tr], dtype=np.float32)
+    n = pts.shape[0]
+    if n < FIXED_LENGTH:
+        pad = np.zeros((FIXED_LENGTH - n, 3), dtype=np.float32)
+        pts = np.vstack((pts, pad))
+    else:
+        pts = pts[:FIXED_LENGTH]
+    return pts.reshape(FIXED_LENGTH, 1, 3)
+
+def ind_sel(target: torch.Tensor, ind: torch.Tensor, dim: int = 1):
+    """Index selection utility function"""
+    assert (
+        len(ind.shape) > dim
+    ), "Index must have the target dim, but get dim: %d, ind shape: %s" % (dim, str(ind.shape))
+
+    target = target.expand(
+        *tuple(
+            [ind.shape[k] if target.shape[k] == 1 else -1 for k in range(dim)]
+            + [
+                -1,
+            ]
+            * (len(target.shape) - dim)
+        )
+    )
+
+    ind_pad = ind
+
+    if len(target.shape) > dim + 1:
+        for _ in range(len(target.shape) - (dim + 1)):
+            ind_pad = ind_pad.unsqueeze(-1)
+        ind_pad = ind_pad.expand(*(-1,) * (dim + 1), *target.shape[(dim + 1) : :])
+
+    return torch.gather(target, dim=dim, index=ind_pad)
+
+def merge_final(vert_attr: torch.Tensor, weight: torch.Tensor, vert_assign: torch.Tensor):
+    """Merge vertex attributes with weights"""
+    target_dim = len(vert_assign.shape) - 1
+    if len(vert_attr.shape) == 2:
+        assert vert_attr.shape[0] > vert_assign.max()
+        new_shape = [1] * target_dim + list(vert_attr.shape)
+        tensor = vert_attr.reshape(new_shape)
+        sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim)
+    else:
+        assert vert_attr.shape[1] > vert_assign.max()
+        new_shape = [vert_attr.shape[0]] + [1] * (target_dim - 1) + list(vert_attr.shape[1:])
+        tensor = vert_attr.reshape(new_shape)
+        sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim)
+
+    final_attr = torch.sum(sel_attr * weight.unsqueeze(-1), dim=-2)
+    return final_attr
+
+
+def _patch_motion_single(
+    tracks: torch.FloatTensor,  # (B, T, N, 4)
+    vid: torch.FloatTensor,  # (C, T, H, W)
+    temperature: float,
+    vae_divide: tuple,
+    topk: int,
+):
+    """Apply motion patching based on tracks"""
+    _, T, H, W = vid.shape
+    N = tracks.shape[2]
+    _, tracks_xy, visible = torch.split(
+        tracks, [1, 2, 1], dim=-1
+    )  # (B, T, N, 2) | (B, T, N, 1)
+    tracks_n = tracks_xy / torch.tensor([W / min(H, W), H / min(H, W)], device=tracks_xy.device)
+    tracks_n = tracks_n.clamp(-1, 1)
+    visible = visible.clamp(0, 1)
+
+    xx = torch.linspace(-W / min(H, W), W / min(H, W), W)
+    yy = torch.linspace(-H / min(H, W), H / min(H, W), H)
+
+    grid = torch.stack(torch.meshgrid(yy, xx, indexing="ij")[::-1], dim=-1).to(
+        tracks_xy.device
+    )
+
+    tracks_pad = tracks_xy[:, 1:]
+    visible_pad = visible[:, 1:]
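+    # vid here is the encoded latent, so T counts latent frames while the tracks carry one
+    # sample per original video frame; the view(T - 1, 4, ...) below folds the four frame
+    # samples that land in each latent frame and averages them, weighted by visibility.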
+
+    visible_align = visible_pad.view(T - 1, 4, *visible_pad.shape[2:]).sum(1)
+    tracks_align = (tracks_pad * visible_pad).view(T - 1, 4, *tracks_pad.shape[2:]).sum(
+        1
+    ) / (visible_align + 1e-5)
+    dist_ = (
+        (tracks_align[:, None, None] - grid[None, :, :, None]).pow(2).sum(-1)
+    )  # T, H, W, N
+    weight = torch.exp(-dist_ * temperature) * visible_align.clamp(0, 1).view(
+        T - 1, 1, 1, N
+    )
+    vert_weight, vert_index = torch.topk(
+        weight, k=min(topk, weight.shape[-1]), dim=-1
+    )
+
+    grid_mode = "bilinear"
+    point_feature = torch.nn.functional.grid_sample(
+        vid.permute(1, 0, 2, 3)[:1],
+        tracks_n[:, :1].type(vid.dtype),
+        mode=grid_mode,
+        padding_mode="zeros",
+        align_corners=False,
+    )
+    point_feature = point_feature.squeeze(0).squeeze(1).permute(1, 0)  # N, C=16
+
+    out_feature = merge_final(point_feature, vert_weight, vert_index).permute(3, 0, 1, 2)  # T - 1, H, W, C => C, T - 1, H, W
+    out_weight = vert_weight.sum(-1)  # T - 1, H, W
+
+    # out feature -> already soft weighted
+    mix_feature = out_feature + vid[:, 1:] * (1 - out_weight.clamp(0, 1))
+
+    out_feature_full = torch.cat([vid[:, :1], mix_feature], dim=1)  # C, T, H, W
+    out_mask_full = torch.cat([torch.ones_like(out_weight[:1]), out_weight], dim=0)  # T, H, W
+
+    return out_mask_full[None].expand(vae_divide[0], -1, -1, -1), out_feature_full
+
+
+def patch_motion(
+    tracks: torch.FloatTensor,  # list of B tensors, each (1, T, N, 4)
+    vid: torch.FloatTensor,  # (B, C, T, H, W)
+    temperature: float = 220.0,
+    vae_divide: tuple = (4, 16),
+    topk: int = 2,
+):
+    B = len(tracks)
+
+    # Process each batch separately
+    out_masks = []
+    out_features = []
+
+    for b in range(B):
+        mask, feature = _patch_motion_single(
+            tracks[b],  # (1, T, N, 4)
+            vid[b],  # (C, T, H, W)
+            temperature,
+            vae_divide,
+            topk
+        )
+        out_masks.append(mask)
+        out_features.append(feature)
+
+    # Stack results: (B, C, T, H, W)
+    out_mask_full = torch.stack(out_masks, dim=0)
+    out_feature_full = torch.stack(out_features, dim=0)
+
+    return out_mask_full, out_feature_full
+
+class WanTrackToVideo:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+                    "positive": ("CONDITIONING", ),
+                    "negative": ("CONDITIONING", ),
+                    "vae": ("VAE", ),
+                    "tracks": ("STRING", {"multiline": True, "default": "[]"}),
+                    "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                    "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                    "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
+                    "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                    "temperature": ("FLOAT", {"default": 220.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
+                    "topk": ("INT", {"default": 2, "min": 1, "max": 10}),
+                    "start_image": ("IMAGE", ),
+                },
+                "optional": {
+                    "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
+                }}
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/video_models"
+
+    def encode(self, positive, negative, vae, tracks, width, height, length, batch_size,
+               temperature, topk, start_image=None, clip_vision_output=None):
+
+        tracks_data = parse_json_tracks(tracks)
+
+        if not tracks_data:
+            return WanImageToVideo().encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output)
+
+        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
+                             device=comfy.model_management.intermediate_device())
+
+        if isinstance(tracks_data[0][0], dict):
+            tracks_data = [tracks_data]
+
+        processed_tracks = []
+        for batch in tracks_data:
+            arrs = []
+            for track in batch:
+                pts = pad_pts(track)
+                arrs.append(pts)
+
+            tracks_np = np.stack(arrs, axis=0)
+            processed_tracks.append(process_tracks(tracks_np, (width, height), length - 1).unsqueeze(0))
+
+        if start_image is not None:
+            start_image = comfy.utils.common_upscale(start_image[:batch_size].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+            videos = torch.ones((start_image.shape[0], length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
+            for i in range(start_image.shape[0]):
+                videos[i, 0] = start_image[i]
+
+            latent_videos = []
+            videos = comfy.utils.resize_to_batch_size(videos, batch_size)
+            for i in range(batch_size):
+                latent_videos += [vae.encode(videos[i, :, :, :, :3])]
+            y = torch.cat(latent_videos, dim=0)
+
+            # Scale latent since patch_motion is non-linear
+            y = comfy.latent_formats.Wan21().process_in(y)
+
+            processed_tracks = comfy.utils.resize_list_to_batch_size(processed_tracks, batch_size)
+            res = patch_motion(
+                processed_tracks, y, temperature=temperature, topk=topk, vae_divide=(4, 16)
+            )
+
+            mask, concat_latent_image = res
+            concat_latent_image = comfy.latent_formats.Wan21().process_out(concat_latent_image)
+            mask = -mask + 1.0  # Invert mask to match expected format
+            positive = node_helpers.conditioning_set_values(positive,
+                                                            {"concat_mask": mask,
+                                                             "concat_latent_image": concat_latent_image})
+            negative = node_helpers.conditioning_set_values(negative,
+                                                            {"concat_mask": mask,
+                                                             "concat_latent_image": concat_latent_image})
+
+        if clip_vision_output is not None:
+            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
+            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
+
+        out_latent = {}
+        out_latent["samples"] = latent
+        return (positive, negative, out_latent)
+
 NODE_CLASS_MAPPINGS = {
+    "WanTrackToVideo": WanTrackToVideo,
     "WanImageToVideo": WanImageToVideo,
     "WanFunControlToVideo": WanFunControlToVideo,
     "WanFunInpaintToVideo": WanFunInpaintToVideo,
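
For reference, a minimal sketch of the value the new WanTrackToVideo node expects on its "tracks"
input (illustrative only; the coordinates and point counts below are made up): parse_json_tracks()
accepts either a single track or a list of tracks, each track being a list of {x, y} points in
pixel coordinates, and pad_pts() pads or truncates every track to 121 samples.

import json

# Illustrative only: two hand-written tracks for an 832x480, 81-frame generation.
track_a = [{"x": 100 + 4 * i, "y": 240} for i in range(81)]   # point drifting right
track_b = [{"x": 600, "y": 120 + 2 * i} for i in range(81)]   # point drifting down

# The JSON string is what goes into the node's "tracks" STRING input;
# parse_json_tracks() also tolerates single-quoted JSON.
tracks_value = json.dumps([track_a, track_b])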