Merge branch 'master' into worksplit-multigpu

This commit is contained in:
Jedrzej Kosinski 2025-07-28 06:17:24 -07:00
commit 5d5024296d
15 changed files with 1078 additions and 131 deletions

View File

@ -55,7 +55,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Image Models
- SD1.x, SD2.x,
- SD1.x, SD2.x ([unCLIP](https://comfyanonymous.github.io/ComfyUI_examples/unclip/))
- [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
- [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/)
@ -77,6 +77,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
- Audio Models
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
@ -84,9 +85,9 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
- Smart memory management: can automatically run large models on GPUs with as low as 1GB vram with smart offloading.
- Works even if you don't have a GPU with: ```--cpu``` (slow)
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
- Can load ckpt and safetensors: All in one checkpoints or standalone diffusion models, VAEs and CLIP models.
- Safe loading of ckpt, pt, pth, etc.. files.
- Embeddings/Textual inversion
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
@ -98,7 +99,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
- [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)

View File

@ -457,6 +457,82 @@ class Wan21(LatentFormat):
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
class Wan22(Wan21):
latent_channels = 48
latent_dimensions = 3
latent_rgb_factors = [
[ 0.0119, 0.0103, 0.0046],
[-0.1062, -0.0504, 0.0165],
[ 0.0140, 0.0409, 0.0491],
[-0.0813, -0.0677, 0.0607],
[ 0.0656, 0.0851, 0.0808],
[ 0.0264, 0.0463, 0.0912],
[ 0.0295, 0.0326, 0.0590],
[-0.0244, -0.0270, 0.0025],
[ 0.0443, -0.0102, 0.0288],
[-0.0465, -0.0090, -0.0205],
[ 0.0359, 0.0236, 0.0082],
[-0.0776, 0.0854, 0.1048],
[ 0.0564, 0.0264, 0.0561],
[ 0.0006, 0.0594, 0.0418],
[-0.0319, -0.0542, -0.0637],
[-0.0268, 0.0024, 0.0260],
[ 0.0539, 0.0265, 0.0358],
[-0.0359, -0.0312, -0.0287],
[-0.0285, -0.1032, -0.1237],
[ 0.1041, 0.0537, 0.0622],
[-0.0086, -0.0374, -0.0051],
[ 0.0390, 0.0670, 0.2863],
[ 0.0069, 0.0144, 0.0082],
[ 0.0006, -0.0167, 0.0079],
[ 0.0313, -0.0574, -0.0232],
[-0.1454, -0.0902, -0.0481],
[ 0.0714, 0.0827, 0.0447],
[-0.0304, -0.0574, -0.0196],
[ 0.0401, 0.0384, 0.0204],
[-0.0758, -0.0297, -0.0014],
[ 0.0568, 0.1307, 0.1372],
[-0.0055, -0.0310, -0.0380],
[ 0.0239, -0.0305, 0.0325],
[-0.0663, -0.0673, -0.0140],
[-0.0416, -0.0047, -0.0023],
[ 0.0166, 0.0112, -0.0093],
[-0.0211, 0.0011, 0.0331],
[ 0.1833, 0.1466, 0.2250],
[-0.0368, 0.0370, 0.0295],
[-0.3441, -0.3543, -0.2008],
[-0.0479, -0.0489, -0.0420],
[-0.0660, -0.0153, 0.0800],
[-0.0101, 0.0068, 0.0156],
[-0.0690, -0.0452, -0.0927],
[-0.0145, 0.0041, 0.0015],
[ 0.0421, 0.0451, 0.0373],
[ 0.0504, -0.0483, -0.0356],
[-0.0837, 0.0168, 0.0055]
]
latent_rgb_factors_bias = [0.0317, -0.0878, -0.1388]
def __init__(self):
self.scale_factor = 1.0
self.latents_mean = torch.tensor([
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
]).view(1, self.latent_channels, 1, 1, 1)
self.latents_std = torch.tensor([
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
]).view(1, self.latent_channels, 1, 1, 1)
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1

View File

@ -201,8 +201,10 @@ class WanAttentionBlock(nn.Module):
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
# assert e.dtype == torch.float32
if e.ndim < 4:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
# assert e[0].dtype == torch.float32
# self-attention
@ -325,7 +327,10 @@ class Head(nn.Module):
e(Tensor): Shape [B, C]
"""
# assert e.dtype == torch.float32
if e.ndim < 3:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
return x
@ -506,8 +511,9 @@ class WanModel(torch.nn.Module):
# time embeddings
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
e = e.reshape(t.shape[0], -1, e.shape[-1])
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
# context
context = self.text_embedding(context)

726
comfy/ldm/wan/vae2_2.py Normal file
View File

@ -0,0 +1,726 @@
# original version: https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/vae2_2.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .vae import AttentionBlock, CausalConv3d, RMS_norm
import comfy.ops
ops = comfy.ops.disable_weight_init
CACHE_T = 2
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in (
"none",
"upsample2d",
"upsample3d",
"downsample2d",
"downsample3d",
)
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == "upsample2d":
self.resample = nn.Sequential(
nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
ops.Conv2d(dim, dim, 3, padding=1),
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
ops.Conv2d(dim, dim, 3, padding=1),
# ops.Conv2d(dim, dim//2, 3, padding=1)
)
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
ops.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == "downsample3d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
ops.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == "upsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
feat_cache[idx] != "Rep"):
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
feat_cache[idx] == "Rep"):
cache_x = torch.cat(
[
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2,
)
if feat_cache[idx] == "Rep":
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.resample(x)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
if self.mode == "downsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False),
nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False),
nn.SiLU(),
nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1),
)
self.shortcut = (
CausalConv3d(in_dim, out_dim, 1)
if in_dim != out_dim else nn.Identity())
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + h
def patchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
elif x.dim() == 5:
x = rearrange(
x,
"b c f (h q) (w r) -> b (c r q) f h w",
q=patch_size,
r=patch_size,
)
else:
raise ValueError(f"Invalid input shape: {x.shape}")
return x
def unpatchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
elif x.dim() == 5:
x = rearrange(
x,
"b (c r q) f h w -> b c f (h q) (w r)",
q=patch_size,
r=patch_size,
)
return x
class AvgDown3D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert in_channels * self.factor % out_channels == 0
self.group_size = in_channels * self.factor // out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
pad = (0, 0, 0, 0, pad_t, 0)
x = F.pad(x, pad)
B, C, T, H, W = x.shape
x = x.view(
B,
C,
T // self.factor_t,
self.factor_t,
H // self.factor_s,
self.factor_s,
W // self.factor_s,
self.factor_s,
)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
x = x.view(
B,
C * self.factor,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.view(
B,
self.out_channels,
self.group_size,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.mean(dim=2)
return x
class DupUp3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert out_channels * self.factor % in_channels == 0
self.repeats = out_channels * self.factor // in_channels
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
x = x.repeat_interleave(self.repeats, dim=1)
x = x.view(
x.size(0),
self.out_channels,
self.factor_t,
self.factor_s,
self.factor_s,
x.size(2),
x.size(3),
x.size(4),
)
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
x = x.view(
x.size(0),
self.out_channels,
x.size(2) * self.factor_t,
x.size(4) * self.factor_s,
x.size(6) * self.factor_s,
)
if first_chunk:
x = x[:, :, self.factor_t - 1:, :, :]
return x
class Down_ResidualBlock(nn.Module):
def __init__(self,
in_dim,
out_dim,
dropout,
mult,
temperal_downsample=False,
down_flag=False):
super().__init__()
# Shortcut path with downsample
self.avg_shortcut = AvgDown3D(
in_dim,
out_dim,
factor_t=2 if temperal_downsample else 1,
factor_s=2 if down_flag else 1,
)
# Main path with residual blocks and downsample
downsamples = []
for _ in range(mult):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
in_dim = out_dim
# Add the final downsample block
if down_flag:
mode = "downsample3d" if temperal_downsample else "downsample2d"
downsamples.append(Resample(out_dim, mode=mode))
self.downsamples = nn.Sequential(*downsamples)
def forward(self, x, feat_cache=None, feat_idx=[0]):
x_copy = x.clone()
for module in self.downsamples:
x = module(x, feat_cache, feat_idx)
return x + self.avg_shortcut(x_copy)
class Up_ResidualBlock(nn.Module):
def __init__(self,
in_dim,
out_dim,
dropout,
mult,
temperal_upsample=False,
up_flag=False):
super().__init__()
# Shortcut path with upsample
if up_flag:
self.avg_shortcut = DupUp3D(
in_dim,
out_dim,
factor_t=2 if temperal_upsample else 1,
factor_s=2 if up_flag else 1,
)
else:
self.avg_shortcut = None
# Main path with residual blocks and upsample
upsamples = []
for _ in range(mult):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
in_dim = out_dim
# Add the final upsample block
if up_flag:
mode = "upsample3d" if temperal_upsample else "upsample2d"
upsamples.append(Resample(out_dim, mode=mode))
self.upsamples = nn.Sequential(*upsamples)
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
x_main = x.clone()
for module in self.upsamples:
x_main = module(x_main, feat_cache, feat_idx)
if self.avg_shortcut is not None:
x_shortcut = self.avg_shortcut(x, first_chunk)
return x_main + x_shortcut
else:
return x_main
class Encoder3d(nn.Module):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
t_down_flag = (
temperal_downsample[i]
if i < len(temperal_downsample) else False)
downsamples.append(
Down_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
dropout=dropout,
mult=num_res_blocks,
temperal_downsample=t_down_flag,
down_flag=i != len(dim_mult) - 1,
))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(out_dim, out_dim, dropout),
AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout),
)
# # output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False),
nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1),
)
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
class Decoder3d(nn.Module):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout),
)
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
t_up_flag = temperal_upsample[i] if i < len(
temperal_upsample) else False
upsamples.append(
Up_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
dropout=dropout,
mult=num_res_blocks + 1,
temperal_upsample=t_up_flag,
up_flag=i != len(dim_mult) - 1,
))
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False),
nn.SiLU(),
CausalConv3d(out_dim, 12, 3, padding=1),
)
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx, first_chunk)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, CausalConv3d):
count += 1
return count
class WanVAE(nn.Module):
def __init__(
self,
dim=160,
dec_dim=256,
z_dim=16,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(
dim,
z_dim * 2,
dim_mult,
num_res_blocks,
attn_scales,
self.temperal_downsample,
dropout,
)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(
dec_dim,
z_dim,
dim_mult,
num_res_blocks,
attn_scales,
self.temperal_upsample,
dropout,
)
def encode(self, x):
self.clear_cache()
x = patchify(x, patch_size=2)
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
self.clear_cache()
return mu
def decode(self, z):
self.clear_cache()
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
first_chunk=True,
)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
)
out = torch.cat([out, out_], 2)
out = unpatchify(out, patch_size=2)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num

View File

@ -1097,6 +1097,7 @@ class WAN21(BaseModel):
image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
image = utils.resize_to_batch_size(image, noise.shape[0])
if extra_channels != image.shape[1] + 4:
if not self.image_to_video or extra_channels == image.shape[1]:
return image
@ -1182,6 +1183,31 @@ class WAN21_Camera(WAN21):
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
return out
class WAN22(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if denoise_mask is not None:
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
if denoise_mask is None:
return timestep
temp_ts = (torch.mean(denoise_mask[:, :, :, ::2, ::2], dim=1, keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
return temp_ts
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)

View File

@ -346,7 +346,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config = {}
dit_config["image_model"] = "wan2.1"
dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
out_dim = state_dict['{}head.head.weight'.format(key_prefix)].shape[0] // 4
dit_config["dim"] = dim
dit_config["out_dim"] = out_dim
dit_config["num_heads"] = dim // 128
dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')

View File

@ -14,6 +14,7 @@ import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import yaml
@ -420,6 +421,19 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.middle.0.residual.0.gamma" in sd:
if "decoder.upsamples.0.upsamples.0.residual.2.weight" in sd: # Wan 2.2 VAE
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
self.latent_dim = 3
self.latent_channels = 48
ddconfig = {"dim": 160, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae2_2.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
else: # Wan 2.1 VAE
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)

View File

@ -1059,6 +1059,19 @@ class WAN21_Vace(WAN21_T2V):
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
return out
class WAN22_T2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "t2v",
"out_dim": 48,
}
latent_format = latent_formats.Wan22
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN22(self, image_to_video=True, device=device)
return out
class Hunyuan3Dv2(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan3d2",
@ -1217,6 +1230,6 @@ class Omnigen2(supported_models_base.BASE):
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
models += [SVD_img2vid]

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, Union
import io
import av
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
class VideoInput(ABC):
@ -70,3 +71,15 @@ class VideoInput(ABC):
components = self.get_components()
frame_count = components.images.shape[0]
return float(frame_count / components.frame_rate)
def get_container_format(self) -> str:
"""
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
Returns:
Container format as string
"""
# Default implementation - subclasses should override for better performance
source = self.get_stream_source()
with av.open(source, mode="r") as container:
return container.format.name

View File

@ -121,6 +121,18 @@ class VideoFromFile(VideoInput):
raise ValueError(f"Could not determine duration for file '{self.__file}'")
def get_container_format(self) -> str:
"""
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
Returns:
Container format as string
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode='r') as container:
return container.format.name
def get_components_internal(self, container: InputContainer) -> VideoComponents:
# Get video frames
frames = []

View File

@ -5,7 +5,6 @@ import torch
from comfy_api_nodes.util.validation_utils import (
get_image_dimensions,
validate_image_dimensions,
validate_video_dimensions,
)
@ -176,54 +175,76 @@ def validate_input_image(
)
def validate_input_video(
video: VideoInput, num_frames_out: int, with_frame_conditioning: bool = False
):
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
"""
Validates and processes video input for Moonvalley Video-to-Video generation.
Args:
video: Input video to validate
Returns:
Validated and potentially trimmed video
Raises:
ValueError: If video doesn't meet requirements
MoonvalleyApiError: If video duration is too short
"""
width, height = _get_video_dimensions(video)
_validate_video_dimensions(width, height)
_validate_container_format(video)
return _validate_and_trim_duration(video)
def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
"""Extracts video dimensions with error handling."""
try:
width, height = video.get_dimensions()
return video.get_dimensions()
except Exception as e:
logging.error("Error getting dimensions of video: %s", e)
raise ValueError(f"Cannot get video dimensions: {e}") from e
validate_input_media(width, height, with_frame_conditioning)
validate_video_dimensions(
video,
min_width=MIN_VID_WIDTH,
min_height=MIN_VID_HEIGHT,
max_width=MAX_VID_WIDTH,
max_height=MAX_VID_HEIGHT,
)
trimmed_video = validate_input_video_length(video, num_frames_out)
return trimmed_video
def _validate_video_dimensions(width: int, height: int) -> None:
"""Validates video dimensions meet Moonvalley V2V requirements."""
supported_resolutions = {
(1920, 1080), (1080, 1920), (1152, 1152),
(1536, 1152), (1152, 1536)
}
if (width, height) not in supported_resolutions:
supported_list = ', '.join([f'{w}x{h}' for w, h in sorted(supported_resolutions)])
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
def validate_input_video_length(video: VideoInput, num_frames: int):
def _validate_container_format(video: VideoInput) -> None:
"""Validates video container format is MP4."""
container_format = video.get_container_format()
if container_format not in ['mp4', 'mov,mp4,m4a,3gp,3g2,mj2']:
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
if video.get_duration() > 60:
raise MoonvalleyApiError(
"Input Video lenth should be less than 1min. Please trim."
)
if num_frames == 128:
if video.get_duration() < 5:
raise MoonvalleyApiError(
"Input Video length is less than 5s. Please use a video longer than or equal to 5s."
)
if video.get_duration() > 5:
# trim video to 5s
video = trim_video(video, 5)
if num_frames == 256:
if video.get_duration() < 10:
raise MoonvalleyApiError(
"Input Video length is less than 10s. Please use a video longer than or equal to 10s."
)
if video.get_duration() > 10:
# trim video to 10s
video = trim_video(video, 10)
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
"""Validates video duration and trims to 5 seconds if needed."""
duration = video.get_duration()
_validate_minimum_duration(duration)
return _trim_if_too_long(video, duration)
def _validate_minimum_duration(duration: float) -> None:
"""Ensures video is at least 5 seconds long."""
if duration < 5:
raise MoonvalleyApiError("Input video must be at least 5 seconds long.")
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
"""Trims video to 5 seconds if longer."""
if duration > 5:
return trim_video(video, 5)
return video
def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
"""
Returns a new VideoInput object trimmed from the beginning to the specified duration,
@ -278,15 +299,13 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels"
)
# Calculate target frame count that's divisible by 32
# Calculate target frame count that's divisible by 16
fps = input_container.streams.video[0].average_rate
estimated_frames = int(duration_sec * fps)
target_frames = (
estimated_frames // 32
) * 32 # Round down to nearest multiple of 32
target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16
if target_frames == 0:
raise ValueError("Video too short: need at least 32 frames for Moonvalley")
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
frame_count = 0
audio_frame_count = 0
@ -353,8 +372,8 @@ class BaseMoonvalleyVideoNode:
"16:9 (1920 x 1080)": {"width": 1920, "height": 1080},
"9:16 (1080 x 1920)": {"width": 1080, "height": 1920},
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
"4:3 (1440 x 1080)": {"width": 1440, "height": 1080},
"3:4 (1080 x 1440)": {"width": 1080, "height": 1440},
"4:3 (1536 x 1152)": {"width": 1536, "height": 1152},
"3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
"21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
}
if resolution in res_map:
@ -494,7 +513,6 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
image = kwargs.get("image", None)
if image is None:
raise MoonvalleyApiError("image is required")
total_frames = get_total_frames_from_length()
validate_input_image(image, True)
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
@ -505,7 +523,7 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
steps=kwargs.get("steps"),
seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"),
num_frames=total_frames,
num_frames=128,
width=width_height.get("width"),
height=width_height.get("height"),
use_negative_prompts=True,
@ -549,21 +567,28 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
@classmethod
def INPUT_TYPES(cls):
input_types = super().INPUT_TYPES()
for param in ["resolution", "image"]:
if param in input_types["required"]:
del input_types["required"][param]
if param in input_types["optional"]:
del input_types["optional"][param]
input_types["optional"] = {
"video": (
IO.VIDEO,
{
"default": "",
"multiline": False,
"tooltip": "The reference video used to generate the output video. Input a 5s video for 128 frames and a 10s video for 256 frames. Longer videos will be trimmed automatically.",
},
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, MoonvalleyVideoToVideoRequest, "prompt_text",
multiline=True
),
"negative_prompt": model_field_to_node_input(
IO.STRING,
MoonvalleyVideoToVideoInferenceParams,
"negative_prompt",
multiline=True,
default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts"
),
"seed": model_field_to_node_input(IO.INT,MoonvalleyVideoToVideoInferenceParams, "seed", default=random.randint(0, 2**32 - 1), min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", control_after_generate=True),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
"optional": {
"video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported."}),
"control_type": (
["Motion Transfer", "Pose Transfer"],
{"default": "Motion Transfer"},
@ -577,10 +602,9 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
"max": 100,
"tooltip": "Only used if control_type is 'Motion Transfer'",
},
),
)
}
}
return input_types
RETURN_TYPES = ("VIDEO",)
RETURN_NAMES = ("video",)
@ -589,15 +613,13 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
video = kwargs.get("video")
num_frames = get_total_frames_from_length()
if not video:
raise MoonvalleyApiError("video is required")
"""Validate video input"""
video_url = ""
if video:
validated_video = validate_input_video(video, num_frames, False)
validated_video = validate_video_to_video_input(video)
video_url = upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs)
control_type = kwargs.get("control_type")
@ -605,12 +627,16 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
"""Validate prompts and inference input"""
validate_prompts(prompt, negative_prompt)
# Only include motion_intensity for Motion Transfer
control_params = {}
if control_type == "Motion Transfer" and motion_intensity is not None:
control_params['motion_intensity'] = motion_intensity
inference_params=MoonvalleyVideoToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=kwargs.get("steps"),
seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"),
control_params={"motion_intensity": motion_intensity},
control_params=control_params
)
control = self.parseControlParameter(control_type)
@ -667,14 +693,13 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
):
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
num_frames = get_total_frames_from_length()
inference_params=MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=kwargs.get("steps"),
seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"),
num_frames=num_frames,
num_frames=128,
width=width_height.get("width"),
height=width_height.get("height"),
)
@ -707,22 +732,12 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
NODE_CLASS_MAPPINGS = {
"MoonvalleyImg2VideoNode": MoonvalleyImg2VideoNode,
"MoonvalleyTxt2VideoNode": MoonvalleyTxt2VideoNode,
# "MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode,
"MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"MoonvalleyImg2VideoNode": "Moonvalley Marey Image to Video",
"MoonvalleyTxt2VideoNode": "Moonvalley Marey Text to Video",
# "MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video",
"MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video",
}
def get_total_frames_from_length(length="5s"):
# if length == '5s':
# return 128
# elif length == '10s':
# return 256
return 128
# else:
# raise MoonvalleyApiError("length is required")

View File

@ -685,6 +685,49 @@ class WanTrackToVideo:
out_latent["samples"] = latent
return (positive, negative, out_latent)
class Wan22ImageToVideoLatent:
@classmethod
def INPUT_TYPES(s):
return {"required": {"vae": ("VAE", ),
"width": ("INT", {"default": 1280, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}),
"height": ("INT", {"default": 704, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}),
"length": ("INT", {"default": 49, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"start_image": ("IMAGE", ),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "conditioning/inpaint"
def encode(self, vae, width, height, length, batch_size, start_image=None):
latent = torch.zeros([1, 48, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
if start_image is None:
out_latent = {}
out_latent["samples"] = latent
return (out_latent,)
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
latent_temp = vae.encode(start_image)
latent[:, :, :latent_temp.shape[-3]] = latent_temp
mask[:, :, :latent_temp.shape[-3]] *= 0.0
out_latent = {}
latent_format = comfy.latent_formats.Wan22()
latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
return (out_latent,)
NODE_CLASS_MAPPINGS = {
"WanTrackToVideo": WanTrackToVideo,
"WanImageToVideo": WanImageToVideo,
@ -695,4 +738,5 @@ NODE_CLASS_MAPPINGS = {
"TrimVideoLatent": TrimVideoLatent,
"WanCameraImageToVideo": WanCameraImageToVideo,
"WanPhantomSubjectToVideo": WanPhantomSubjectToVideo,
"Wan22ImageToVideoLatent": Wan22ImageToVideoLatent,
}

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.45"
__version__ = "0.3.46"

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.45"
version = "0.3.46"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.23.4
comfyui-workflow-templates==0.1.40
comfyui-workflow-templates==0.1.41
comfyui-embedded-docs==0.2.4
torch
torchsde