Merge branch 'master' into v3-definition

Jedrzej Kosinski, 2025-07-29 12:49:16 -07:00
commit 930f8d9e6d
19 changed files with 1534 additions and 239 deletions


@ -17,8 +17,7 @@ jobs:
- name: Check for Windows line endings (CRLF)
  run: |
    # Get the list of changed files in the PR
-   git merge origin/${{ github.base_ref }} --no-edit
-   CHANGED_FILES=$(git diff --name-only origin/${{ github.base_ref }}..HEAD)
+   CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }})
    # Flag to track if CRLF is found
    CRLF_FOUND=false
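The check now diffs the pull request's recorded base and head SHAs directly instead of merging the base branch into the checkout first. A rough local equivalent of the same logic, for reference only (the helper names below are illustrative, not part of the workflow):

```python
import subprocess

def changed_files(base_sha: str, head_sha: str) -> list[str]:
    # Same idea as the workflow step: list files touched between base and head.
    result = subprocess.run(
        ["git", "diff", "--name-only", f"{base_sha}..{head_sha}"],
        capture_output=True, text=True, check=True,
    )
    return [line for line in result.stdout.splitlines() if line]

def has_crlf(path: str) -> bool:
    # Read raw bytes so platform newline translation cannot hide a CRLF.
    try:
        with open(path, "rb") as f:
            return b"\r\n" in f.read()
    except (FileNotFoundError, IsADirectoryError):
        return False  # deleted or non-regular paths are skipped
```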


@ -55,7 +55,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Image Models
- - SD1.x, SD2.x,
+ - SD1.x, SD2.x ([unCLIP](https://comfyanonymous.github.io/ComfyUI_examples/unclip/))
- [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
- [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/)
@ -77,6 +77,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
+ - [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
- Audio Models
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
@ -84,9 +85,9 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- - Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
+ - Smart memory management: can automatically run large models on GPUs with as low as 1GB vram with smart offloading.
- Works even if you don't have a GPU with: ```--cpu``` (slow)
- - Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
+ - Can load ckpt and safetensors: All in one checkpoints or standalone diffusion models, VAEs and CLIP models.
- Safe loading of ckpt, pt, pth, etc.. files.
- Embeddings/Textual inversion
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
@ -98,7 +99,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
- [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
- - [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)


@ -457,6 +457,82 @@ class Wan21(LatentFormat):
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
class Wan22(Wan21):
latent_channels = 48
latent_dimensions = 3
latent_rgb_factors = [
[ 0.0119, 0.0103, 0.0046],
[-0.1062, -0.0504, 0.0165],
[ 0.0140, 0.0409, 0.0491],
[-0.0813, -0.0677, 0.0607],
[ 0.0656, 0.0851, 0.0808],
[ 0.0264, 0.0463, 0.0912],
[ 0.0295, 0.0326, 0.0590],
[-0.0244, -0.0270, 0.0025],
[ 0.0443, -0.0102, 0.0288],
[-0.0465, -0.0090, -0.0205],
[ 0.0359, 0.0236, 0.0082],
[-0.0776, 0.0854, 0.1048],
[ 0.0564, 0.0264, 0.0561],
[ 0.0006, 0.0594, 0.0418],
[-0.0319, -0.0542, -0.0637],
[-0.0268, 0.0024, 0.0260],
[ 0.0539, 0.0265, 0.0358],
[-0.0359, -0.0312, -0.0287],
[-0.0285, -0.1032, -0.1237],
[ 0.1041, 0.0537, 0.0622],
[-0.0086, -0.0374, -0.0051],
[ 0.0390, 0.0670, 0.2863],
[ 0.0069, 0.0144, 0.0082],
[ 0.0006, -0.0167, 0.0079],
[ 0.0313, -0.0574, -0.0232],
[-0.1454, -0.0902, -0.0481],
[ 0.0714, 0.0827, 0.0447],
[-0.0304, -0.0574, -0.0196],
[ 0.0401, 0.0384, 0.0204],
[-0.0758, -0.0297, -0.0014],
[ 0.0568, 0.1307, 0.1372],
[-0.0055, -0.0310, -0.0380],
[ 0.0239, -0.0305, 0.0325],
[-0.0663, -0.0673, -0.0140],
[-0.0416, -0.0047, -0.0023],
[ 0.0166, 0.0112, -0.0093],
[-0.0211, 0.0011, 0.0331],
[ 0.1833, 0.1466, 0.2250],
[-0.0368, 0.0370, 0.0295],
[-0.3441, -0.3543, -0.2008],
[-0.0479, -0.0489, -0.0420],
[-0.0660, -0.0153, 0.0800],
[-0.0101, 0.0068, 0.0156],
[-0.0690, -0.0452, -0.0927],
[-0.0145, 0.0041, 0.0015],
[ 0.0421, 0.0451, 0.0373],
[ 0.0504, -0.0483, -0.0356],
[-0.0837, 0.0168, 0.0055]
]
latent_rgb_factors_bias = [0.0317, -0.0878, -0.1388]
def __init__(self):
self.scale_factor = 1.0
self.latents_mean = torch.tensor([
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
]).view(1, self.latent_channels, 1, 1, 1)
self.latents_std = torch.tensor([
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
]).view(1, self.latent_channels, 1, 1, 1)
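Wan 2.2 switches to a 48-channel video latent whose per-channel mean and standard deviation are baked into the format. A minimal sketch of how such a format is typically applied, assuming `process_in` is simply the inverse of the `Wan21.process_out` expression quoted above (the mean/std tensors below are stand-ins for the tables, not the real values):

```python
import torch

latent_channels = 48
scale_factor = 1.0
# Stand-ins for the latents_mean / latents_std tables defined in Wan22 above.
latents_mean = torch.randn(1, latent_channels, 1, 1, 1) * 0.1
latents_std = torch.rand(1, latent_channels, 1, 1, 1) + 0.5

def process_in(latent):
    # Normalize raw VAE latents before the diffusion model sees them (assumed direction).
    return (latent - latents_mean) * scale_factor / latents_std

def process_out(latent):
    # Mirrors the Wan21.process_out line shown earlier in this file.
    return latent * latents_std / scale_factor + latents_mean

x = torch.randn(1, latent_channels, 4, 8, 8)  # [batch, channel, frame, height, width]
assert torch.allclose(process_out(process_in(x)), x, atol=1e-4)
```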
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1


@ -201,8 +201,10 @@ class WanAttentionBlock(nn.Module):
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
# assert e.dtype == torch.float32
if e.ndim < 4:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
# assert e[0].dtype == torch.float32
# self-attention
@ -325,7 +327,10 @@ class Head(nn.Module):
e(Tensor): Shape [B, C]
"""
# assert e.dtype == torch.float32
if e.ndim < 3:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
return x
@ -506,8 +511,9 @@ class WanModel(torch.nn.Module):
# time embeddings
e = self.time_embedding(
- sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
+ sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
- e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+ e = e.reshape(t.shape[0], -1, e.shape[-1])
+ e0 = self.time_projection(e).unflatten(2, (6, self.dim))
# context
context = self.text_embedding(context)
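With this change `t` may carry one timestep per sample or one per latent frame: it is flattened for the sinusoidal embedding, reshaped back to `[batch, frames, dim]`, and the projection is unflattened along dim 2 so each attention block can broadcast its modulation table over the extra frame axis (the `unsqueeze(0)`/`unbind(2)` branch above). A shape walk-through under the assumption that the per-block modulation parameter has shape `[1, 6, dim]`:

```python
import torch

B, T, dim = 2, 5, 64  # batch, per-frame timesteps, model width (toy sizes)

# Stand-in for time_embedding(sinusoidal_embedding_1d(freq_dim, t.flatten()))
e = torch.randn(B * T, dim)
e = e.reshape(B, -1, e.shape[-1])                 # [B, T, dim]

time_projection = torch.nn.Linear(dim, 6 * dim)   # stand-in for the real projection
e0 = time_projection(e).unflatten(2, (6, dim))    # [B, T, 6, dim]

modulation = torch.randn(1, 6, dim)               # assumed per-block parameter shape
parts = (modulation.unsqueeze(0) + e0).unbind(2)  # six tensors, each [B, T, dim]
assert len(parts) == 6 and parts[0].shape == (B, T, dim)
```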


@ -52,15 +52,6 @@ class RMS_norm(nn.Module):
x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma.to(x) + (self.bias.to(x) if self.bias is not None else 0)
- class Upsample(nn.Upsample):
- def forward(self, x):
- """
- Fix bfloat16 support for nearest neighbor interpolation.
- """
- return super().forward(x.float()).type_as(x)
class Resample(nn.Module):
def __init__(self, dim, mode):
@ -73,11 +64,11 @@ class Resample(nn.Module):
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
- Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
+ nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
ops.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
self.resample = nn.Sequential(
- Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
+ nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
ops.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
@ -157,29 +148,6 @@ class Resample(nn.Module):
feat_idx[0] += 1
return x
- def init_weight(self, conv):
- conv_weight = conv.weight
- nn.init.zeros_(conv_weight)
- c1, c2, t, h, w = conv_weight.size()
- one_matrix = torch.eye(c1, c2)
- init_matrix = one_matrix
- nn.init.zeros_(conv_weight)
- #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
- conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
- conv.weight.data.copy_(conv_weight)
- nn.init.zeros_(conv.bias.data)
- def init_weight2(self, conv):
- conv_weight = conv.weight.data
- nn.init.zeros_(conv_weight)
- c1, c2, t, h, w = conv_weight.size()
- init_matrix = torch.eye(c1 // 2, c2)
- #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
- conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
- conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
- conv.weight.data.copy_(conv_weight)
- nn.init.zeros_(conv.bias.data)
class ResidualBlock(nn.Module):
@ -494,12 +462,6 @@ class WanVAE(nn.Module):
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
- def forward(self, x):
- mu, log_var = self.encode(x)
- z = self.reparameterize(mu, log_var)
- x_recon = self.decode(z)
- return x_recon, mu, log_var
def encode(self, x):
self.clear_cache()
## cache
@ -545,18 +507,6 @@ class WanVAE(nn.Module):
self.clear_cache()
return out
- def reparameterize(self, mu, log_var):
- std = torch.exp(0.5 * log_var)
- eps = torch.randn_like(std)
- return eps * std + mu
- def sample(self, imgs, deterministic=False):
- mu, log_var = self.encode(imgs)
- if deterministic:
- return mu
- std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
- return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]

comfy/ldm/wan/vae2_2.py (new file, 726 lines)

@ -0,0 +1,726 @@
# original version: https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/vae2_2.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .vae import AttentionBlock, CausalConv3d, RMS_norm
import comfy.ops
ops = comfy.ops.disable_weight_init
CACHE_T = 2
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in (
"none",
"upsample2d",
"upsample3d",
"downsample2d",
"downsample3d",
)
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == "upsample2d":
self.resample = nn.Sequential(
nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
ops.Conv2d(dim, dim, 3, padding=1),
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
ops.Conv2d(dim, dim, 3, padding=1),
# ops.Conv2d(dim, dim//2, 3, padding=1)
)
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
ops.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == "downsample3d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
ops.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == "upsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
feat_cache[idx] != "Rep"):
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
feat_cache[idx] == "Rep"):
cache_x = torch.cat(
[
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2,
)
if feat_cache[idx] == "Rep":
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.resample(x)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
if self.mode == "downsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False),
nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False),
nn.SiLU(),
nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1),
)
self.shortcut = (
CausalConv3d(in_dim, out_dim, 1)
if in_dim != out_dim else nn.Identity())
def forward(self, x, feat_cache=None, feat_idx=[0]):
old_x = x
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + self.shortcut(old_x)
def patchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
elif x.dim() == 5:
x = rearrange(
x,
"b c f (h q) (w r) -> b (c r q) f h w",
q=patch_size,
r=patch_size,
)
else:
raise ValueError(f"Invalid input shape: {x.shape}")
return x
def unpatchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
elif x.dim() == 5:
x = rearrange(
x,
"b (c r q) f h w -> b c f (h q) (w r)",
q=patch_size,
r=patch_size,
)
return x
class AvgDown3D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert in_channels * self.factor % out_channels == 0
self.group_size = in_channels * self.factor // out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
pad = (0, 0, 0, 0, pad_t, 0)
x = F.pad(x, pad)
B, C, T, H, W = x.shape
x = x.view(
B,
C,
T // self.factor_t,
self.factor_t,
H // self.factor_s,
self.factor_s,
W // self.factor_s,
self.factor_s,
)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
x = x.view(
B,
C * self.factor,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.view(
B,
self.out_channels,
self.group_size,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.mean(dim=2)
return x
class DupUp3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert out_channels * self.factor % in_channels == 0
self.repeats = out_channels * self.factor // in_channels
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
x = x.repeat_interleave(self.repeats, dim=1)
x = x.view(
x.size(0),
self.out_channels,
self.factor_t,
self.factor_s,
self.factor_s,
x.size(2),
x.size(3),
x.size(4),
)
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
x = x.view(
x.size(0),
self.out_channels,
x.size(2) * self.factor_t,
x.size(4) * self.factor_s,
x.size(6) * self.factor_s,
)
if first_chunk:
x = x[:, :, self.factor_t - 1:, :, :]
return x
class Down_ResidualBlock(nn.Module):
def __init__(self,
in_dim,
out_dim,
dropout,
mult,
temperal_downsample=False,
down_flag=False):
super().__init__()
# Shortcut path with downsample
self.avg_shortcut = AvgDown3D(
in_dim,
out_dim,
factor_t=2 if temperal_downsample else 1,
factor_s=2 if down_flag else 1,
)
# Main path with residual blocks and downsample
downsamples = []
for _ in range(mult):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
in_dim = out_dim
# Add the final downsample block
if down_flag:
mode = "downsample3d" if temperal_downsample else "downsample2d"
downsamples.append(Resample(out_dim, mode=mode))
self.downsamples = nn.Sequential(*downsamples)
def forward(self, x, feat_cache=None, feat_idx=[0]):
x_copy = x
for module in self.downsamples:
x = module(x, feat_cache, feat_idx)
return x + self.avg_shortcut(x_copy)
class Up_ResidualBlock(nn.Module):
def __init__(self,
in_dim,
out_dim,
dropout,
mult,
temperal_upsample=False,
up_flag=False):
super().__init__()
# Shortcut path with upsample
if up_flag:
self.avg_shortcut = DupUp3D(
in_dim,
out_dim,
factor_t=2 if temperal_upsample else 1,
factor_s=2 if up_flag else 1,
)
else:
self.avg_shortcut = None
# Main path with residual blocks and upsample
upsamples = []
for _ in range(mult):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
in_dim = out_dim
# Add the final upsample block
if up_flag:
mode = "upsample3d" if temperal_upsample else "upsample2d"
upsamples.append(Resample(out_dim, mode=mode))
self.upsamples = nn.Sequential(*upsamples)
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
x_main = x
for module in self.upsamples:
x_main = module(x_main, feat_cache, feat_idx)
if self.avg_shortcut is not None:
x_shortcut = self.avg_shortcut(x, first_chunk)
return x_main + x_shortcut
else:
return x_main
class Encoder3d(nn.Module):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
t_down_flag = (
temperal_downsample[i]
if i < len(temperal_downsample) else False)
downsamples.append(
Down_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
dropout=dropout,
mult=num_res_blocks,
temperal_downsample=t_down_flag,
down_flag=i != len(dim_mult) - 1,
))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(out_dim, out_dim, dropout),
AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout),
)
# # output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False),
nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1),
)
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
class Decoder3d(nn.Module):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout),
)
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
t_up_flag = temperal_upsample[i] if i < len(
temperal_upsample) else False
upsamples.append(
Up_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
dropout=dropout,
mult=num_res_blocks + 1,
temperal_upsample=t_up_flag,
up_flag=i != len(dim_mult) - 1,
))
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False),
nn.SiLU(),
CausalConv3d(out_dim, 12, 3, padding=1),
)
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx, first_chunk)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, CausalConv3d):
count += 1
return count
class WanVAE(nn.Module):
def __init__(
self,
dim=160,
dec_dim=256,
z_dim=16,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(
dim,
z_dim * 2,
dim_mult,
num_res_blocks,
attn_scales,
self.temperal_downsample,
dropout,
)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(
dec_dim,
z_dim,
dim_mult,
num_res_blocks,
attn_scales,
self.temperal_upsample,
dropout,
)
def encode(self, x):
self.clear_cache()
x = patchify(x, patch_size=2)
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
self.clear_cache()
return mu
def decode(self, z):
self.clear_cache()
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
first_chunk=True,
)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
)
out = torch.cat([out, out_], 2)
out = unpatchify(out, patch_size=2)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
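The Wan 2.2 VAE above operates on 2x2 pixel patches, which is why `conv1` takes 12 input channels (3 RGB channels x 2 x 2). A self-contained sketch of the patchify round trip, reimplemented here with the same `rearrange` patterns so it runs on its own:

```python
import torch
from einops import rearrange

def patchify(x, patch_size=2):
    # 5D case from vae2_2.py: fold a 2x2 spatial patch into the channel axis.
    return rearrange(x, "b c f (h q) (w r) -> b (c r q) f h w", q=patch_size, r=patch_size)

def unpatchify(x, patch_size=2):
    return rearrange(x, "b (c r q) f h w -> b c f (h q) (w r)", q=patch_size, r=patch_size)

video = torch.randn(1, 3, 9, 64, 64)           # [batch, rgb, frames, height, width]
patched = patchify(video)                       # [1, 12, 9, 32, 32] -> matches conv1's 12 channels
assert patched.shape == (1, 12, 9, 32, 32)
assert torch.equal(unpatchify(patched), video)  # lossless rearrangement
```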


@ -1097,6 +1097,7 @@ class WAN21(BaseModel):
image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
image = utils.resize_to_batch_size(image, noise.shape[0])
if extra_channels != image.shape[1] + 4:
if not self.image_to_video or extra_channels == image.shape[1]:
return image
@ -1182,6 +1183,31 @@ class WAN21_Camera(WAN21):
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
return out
class WAN22(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if denoise_mask is not None:
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
if denoise_mask is None:
return timestep
temp_ts = (torch.mean(denoise_mask[:, :, :, ::2, ::2], dim=1, keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
return temp_ts
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
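The WAN22 model class added in this hunk overrides `process_timestep` so that, when a denoise mask is present, each latent token gets its own effective timestep: the mask is averaged over channels, spatially subsampled by 2, scaled by the per-sample timestep and flattened. A toy shape check of that expression (sizes are arbitrary examples, not real checkpoint dimensions):

```python
import torch

B, C, T, H, W = 2, 48, 4, 8, 8
timestep = torch.tensor([900.0, 250.0])                      # one timestep per sample
denoise_mask = torch.randint(0, 2, (B, C, T, H, W)).float()  # 1 = region to denoise

temp_ts = (
    torch.mean(denoise_mask[:, :, :, ::2, ::2], dim=1, keepdim=True)
    * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))
).reshape(timestep.shape[0], -1)

print(temp_ts.shape)  # torch.Size([2, 64]) -> one value per T * H/2 * W/2 latent position
```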


@ -346,7 +346,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config = {}
dit_config["image_model"] = "wan2.1"
dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
out_dim = state_dict['{}head.head.weight'.format(key_prefix)].shape[0] // 4
dit_config["dim"] = dim
dit_config["out_dim"] = out_dim
dit_config["num_heads"] = dim // 128
dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')


@ -392,6 +392,8 @@ def get_torch_device_name(device):
except:
allocator_backend = ""
return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
elif device.type == "xpu":
return "{} {}".format(device, torch.xpu.get_device_name(device))
else:
return "{}".format(device.type)
elif is_intel_xpu():
@ -527,6 +529,8 @@ WINDOWS = any(platform.win32_ver())
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
if WINDOWS:
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
if args.reserve_vram is not None:
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024


@ -14,6 +14,7 @@ import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import yaml
@ -420,6 +421,19 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.middle.0.residual.0.gamma" in sd:
if "decoder.upsamples.0.upsamples.0.residual.2.weight" in sd: # Wan 2.2 VAE
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
self.latent_dim = 3
self.latent_channels = 48
ddconfig = {"dim": 160, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae2_2.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
else: # Wan 2.1 VAE
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
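For reference, the upscale/downscale tuples above mean a latent of shape `[48, t, h, w]` maps to `max(0, t*4 - 3)` frames and `16*h x 16*w` pixels for the Wan 2.2 VAE (8x spatially for Wan 2.1). A quick worked example of those formulas (not code from the commit):

```python
import math

def pixels_to_latent(frames, height, width, spatial):  # spatial = 16 for Wan 2.2, 8 for Wan 2.1
    return max(0, math.floor((frames + 3) / 4)), height // spatial, width // spatial

def latent_to_pixels(frames, height, width, spatial):
    return max(0, frames * 4 - 3), height * spatial, width * spatial

print(pixels_to_latent(81, 480, 832, 16))  # Wan 2.2: (21, 30, 52)
print(latent_to_pixels(21, 30, 52, 16))    # back to (81, 480, 832)
print(pixels_to_latent(81, 480, 832, 8))   # Wan 2.1: (21, 60, 104)
```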


@ -1059,6 +1059,19 @@ class WAN21_Vace(WAN21_T2V):
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
return out
class WAN22_T2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "t2v",
"out_dim": 48,
}
latent_format = latent_formats.Wan22
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN22(self, image_to_video=True, device=device)
return out
class Hunyuan3Dv2(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan3d2",
@ -1217,6 +1230,6 @@ class Omnigen2(supported_models_base.BASE):
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
- models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
+ models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
models += [SVD_img2vid]


@ -698,6 +698,26 @@ def resize_to_batch_size(tensor, batch_size):
return output
def resize_list_to_batch_size(l, batch_size):
in_batch_size = len(l)
if in_batch_size == batch_size or in_batch_size == 0:
return l
if batch_size <= 1:
return l[:batch_size]
output = []
if batch_size < in_batch_size:
scale = (in_batch_size - 1) / (batch_size - 1)
for i in range(batch_size):
output.append(l[min(round(i * scale), in_batch_size - 1)])
else:
scale = in_batch_size / batch_size
for i in range(batch_size):
output.append(l[min(math.floor((i + 0.5) * scale), in_batch_size - 1)])
return output
def convert_sd_to(state_dict, dtype):
keys = list(state_dict.keys())
for k in keys:
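The new `resize_list_to_batch_size` helper added above mirrors `resize_to_batch_size` for plain Python lists, picking indices by linear interpolation when shrinking and by bucket midpoints when growing. Copied below with a small usage check so the behaviour is easy to see (the example calls are illustrative, not from the commit):

```python
import math

def resize_list_to_batch_size(l, batch_size):
    in_batch_size = len(l)
    if in_batch_size == batch_size or in_batch_size == 0:
        return l
    if batch_size <= 1:
        return l[:batch_size]
    output = []
    if batch_size < in_batch_size:
        scale = (in_batch_size - 1) / (batch_size - 1)
        for i in range(batch_size):
            output.append(l[min(round(i * scale), in_batch_size - 1)])
    else:
        scale = in_batch_size / batch_size
        for i in range(batch_size):
            output.append(l[min(math.floor((i + 0.5) * scale), in_batch_size - 1)])
    return output

print(resize_list_to_batch_size(["a", "b", "c", "d", "e"], 3))  # ['a', 'c', 'e']
print(resize_list_to_batch_size(["a", "b"], 5))                 # ['a', 'a', 'b', 'b', 'b']
```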


@ -2,7 +2,7 @@
## Introduction
- Below are a collection of nodes that work by calling external APIs. More information available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview#api-nodes).
+ Below are a collection of nodes that work by calling external APIs. More information available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview).
## Development


@ -2,7 +2,10 @@ import logging
from typing import Any, Callable, Optional, TypeVar
import random
import torch
- from comfy_api_nodes.util.validation_utils import get_image_dimensions, validate_image_dimensions, validate_video_dimensions
+ from comfy_api_nodes.util.validation_utils import (
+ get_image_dimensions,
+ validate_image_dimensions,
+ )
from comfy_api_nodes.apis import (
@ -10,7 +13,7 @@ from comfy_api_nodes.apis import (
MoonvalleyTextToVideoInferenceParams,
MoonvalleyVideoToVideoInferenceParams,
MoonvalleyVideoToVideoRequest,
MoonvalleyPromptResponse,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
@ -54,20 +57,26 @@ MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing
MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000
R = TypeVar("R")
class MoonvalleyApiError(Exception):
"""Base exception for Moonvalley API errors."""
pass
def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool:
"""Verifies that the initial response contains a task ID."""
return bool(response.id)
def validate_task_creation_response(response) -> None:
if not is_valid_task_creation_response(response):
error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}"
logging.error(error_msg)
raise MoonvalleyApiError(error_msg)
def get_video_from_response(response):
video = response.output_url
logging.info(
@ -102,16 +111,17 @@ def poll_until_finished(
poll_interval=16.0,
failed_statuses=["error"],
status_extractor=lambda response: (
response.status if response and response.status else None
),
auth_kwargs=auth_kwargs,
result_url_extractor=result_url_extractor,
node_id=node_id,
).execute()
def validate_prompts(
prompt: str, negative_prompt: str, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH
):
"""Verifies that the prompt isn't empty and that neither prompt is too long."""
if not prompt:
raise ValueError("Positive prompt is empty")
@ -123,6 +133,7 @@ def validate_prompts(prompt:str, negative_prompt: str, max_length=MOONVALLEY_MAR
)
return True
def validate_input_media(width, height, with_frame_conditioning, num_frames_in=None):
# inference validation
# T = num_frames
@ -130,9 +141,7 @@ def validate_input_media(width, height, with_frame_conditioning, num_frames_in=N
# with image conditioning: H*W must be divisible by 8192
# without image conditioning: T divisible by 32
if num_frames_in and not num_frames_in % 16 == 0:
return False, ("The input video total frame count must be divisible by 16!")
if height % 8 != 0 or width % 8 != 0:
return False, (
@ -147,12 +156,12 @@ def validate_input_media(width, height, with_frame_conditioning, num_frames_in=N
)
else:
if num_frames_in and not num_frames_in % 32 == 0:
return False, ("The input video total frame count must be divisible by 32!")
def validate_input_image(
image: torch.Tensor, with_frame_conditioning: bool = False
) -> None:
"""
Validates the input image adheres to the expectations of the API:
- The image resolution should not be less than 300*300px
@ -161,41 +170,81 @@ def validate_input_image(image: torch.Tensor, with_frame_conditioning: bool=Fals
""" """
height, width = get_image_dimensions(image) height, width = get_image_dimensions(image)
validate_input_media(width, height, with_frame_conditioning) validate_input_media(width, height, with_frame_conditioning)
validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH) validate_image_dimensions(
image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH
)
def validate_input_video(video: VideoInput, num_frames_out: int, with_frame_conditioning: bool=False):
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
"""
Validates and processes video input for Moonvalley Video-to-Video generation.
Args:
video: Input video to validate
Returns:
Validated and potentially trimmed video
Raises:
ValueError: If video doesn't meet requirements
MoonvalleyApiError: If video duration is too short
"""
width, height = _get_video_dimensions(video)
_validate_video_dimensions(width, height)
_validate_container_format(video)
return _validate_and_trim_duration(video)
def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
"""Extracts video dimensions with error handling."""
try: try:
width, height = video.get_dimensions() return video.get_dimensions()
except Exception as e: except Exception as e:
logging.error("Error getting dimensions of video: %s", e) logging.error("Error getting dimensions of video: %s", e)
raise ValueError(f"Cannot get video dimensions: {e}") from e raise ValueError(f"Cannot get video dimensions: {e}") from e
validate_input_media(width, height, with_frame_conditioning)
validate_video_dimensions(video, min_width=MIN_VID_WIDTH, min_height=MIN_VID_HEIGHT, max_width=MAX_VID_WIDTH, max_height=MAX_VID_HEIGHT)
trimmed_video = validate_input_video_length(video, num_frames_out) def _validate_video_dimensions(width: int, height: int) -> None:
return trimmed_video """Validates video dimensions meet Moonvalley V2V requirements."""
supported_resolutions = {
(1920, 1080), (1080, 1920), (1152, 1152),
(1536, 1152), (1152, 1536)
}
if (width, height) not in supported_resolutions:
supported_list = ', '.join([f'{w}x{h}' for w, h in sorted(supported_resolutions)])
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
def validate_input_video_length(video: VideoInput, num_frames: int): def _validate_container_format(video: VideoInput) -> None:
"""Validates video container format is MP4."""
container_format = video.get_container_format()
if container_format not in ['mp4', 'mov,mp4,m4a,3gp,3g2,mj2']:
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
if video.get_duration() > 60:
raise MoonvalleyApiError("Input Video lenth should be less than 1min. Please trim.")
if num_frames == 128: def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
if video.get_duration() < 5: """Validates video duration and trims to 5 seconds if needed."""
raise MoonvalleyApiError("Input Video length is less than 5s. Please use a video longer than or equal to 5s.") duration = video.get_duration()
if video.get_duration() > 5: _validate_minimum_duration(duration)
# trim video to 5s return _trim_if_too_long(video, duration)
video = trim_video(video, 5)
if num_frames == 256:
if video.get_duration() < 10: def _validate_minimum_duration(duration: float) -> None:
raise MoonvalleyApiError("Input Video length is less than 10s. Please use a video longer than or equal to 10s.") """Ensures video is at least 5 seconds long."""
if video.get_duration() > 10: if duration < 5:
# trim video to 10s raise MoonvalleyApiError("Input video must be at least 5 seconds long.")
video = trim_video(video, 10)
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
"""Trims video to 5 seconds if longer."""
if duration > 5:
return trim_video(video, 5)
return video return video
def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
"""
Returns a new VideoInput object trimmed from the beginning to the specified duration,
@ -219,8 +268,8 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
input_source = video.get_stream_source()
# Open containers
input_container = av.open(input_source, mode="r")
output_container = av.open(output_buffer, mode="w", format="mp4")
# Set up output streams for re-encoding
video_stream = None
@ -230,25 +279,33 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
logging.info(f"Found stream: type={stream.type}, class={type(stream)}") logging.info(f"Found stream: type={stream.type}, class={type(stream)}")
if isinstance(stream, av.VideoStream): if isinstance(stream, av.VideoStream):
# Create output video stream with same parameters # Create output video stream with same parameters
video_stream = output_container.add_stream('h264', rate=stream.average_rate) video_stream = output_container.add_stream(
"h264", rate=stream.average_rate
)
video_stream.width = stream.width video_stream.width = stream.width
video_stream.height = stream.height video_stream.height = stream.height
video_stream.pix_fmt = 'yuv420p' video_stream.pix_fmt = "yuv420p"
logging.info(f"Added video stream: {stream.width}x{stream.height} @ {stream.average_rate}fps") logging.info(
f"Added video stream: {stream.width}x{stream.height} @ {stream.average_rate}fps"
)
elif isinstance(stream, av.AudioStream): elif isinstance(stream, av.AudioStream):
# Create output audio stream with same parameters # Create output audio stream with same parameters
audio_stream = output_container.add_stream('aac', rate=stream.sample_rate) audio_stream = output_container.add_stream(
"aac", rate=stream.sample_rate
)
audio_stream.sample_rate = stream.sample_rate audio_stream.sample_rate = stream.sample_rate
audio_stream.layout = stream.layout audio_stream.layout = stream.layout
logging.info(f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels") logging.info(
f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels"
)
# Calculate target frame count that's divisible by 32 # Calculate target frame count that's divisible by 16
fps = input_container.streams.video[0].average_rate fps = input_container.streams.video[0].average_rate
estimated_frames = int(duration_sec * fps) estimated_frames = int(duration_sec * fps)
target_frames = (estimated_frames // 32) * 32 # Round down to nearest multiple of 32 target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16
if target_frames == 0: if target_frames == 0:
raise ValueError("Video too short: need at least 32 frames for Moonvalley") raise ValueError("Video too short: need at least 16 frames for Moonvalley")
frame_count = 0 frame_count = 0
audio_frame_count = 0 audio_frame_count = 0
@ -268,7 +325,9 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
for packet in video_stream.encode():
output_container.mux(packet)
logging.info(
f"Encoded {frame_count} video frames (target: {target_frames})"
)
# Decode and re-encode audio frames
if audio_stream:
@ -292,7 +351,6 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
output_container.close()
input_container.close()
# Return as VideoFromFile using the buffer
output_buffer.seek(0)
return VideoFromFile(output_buffer)
@ -305,6 +363,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
output_container.close()
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
# --- BaseMoonvalleyVideoNode ---
class BaseMoonvalleyVideoNode:
def parseWidthHeightFromRes(self, resolution: str):
@ -313,8 +372,8 @@ class BaseMoonvalleyVideoNode:
"16:9 (1920 x 1080)": {"width": 1920, "height": 1080}, "16:9 (1920 x 1080)": {"width": 1920, "height": 1080},
"9:16 (1080 x 1920)": {"width": 1080, "height": 1920}, "9:16 (1080 x 1920)": {"width": 1080, "height": 1920},
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152}, "1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
"4:3 (1440 x 1080)": {"width": 1440, "height": 1080}, "4:3 (1536 x 1152)": {"width": 1536, "height": 1152},
"3:4 (1080 x 1440)": {"width": 1080, "height": 1440}, "3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
"21:9 (2560 x 1080)": {"width": 2560, "height": 1080}, "21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
} }
if resolution in res_map: if resolution in res_map:
@ -328,7 +387,7 @@ class BaseMoonvalleyVideoNode:
"Motion Transfer": "motion_control", "Motion Transfer": "motion_control",
"Canny": "canny_control", "Canny": "canny_control",
"Pose Transfer": "pose_control", "Pose Transfer": "pose_control",
"Depth": "depth_control" "Depth": "depth_control",
} }
if value in control_map: if value in control_map:
return control_map[value] return control_map[value]
@ -355,31 +414,63 @@ class BaseMoonvalleyVideoNode:
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING,
MoonvalleyTextToVideoRequest,
"prompt_text",
multiline=True,
),
"negative_prompt": model_field_to_node_input(
IO.STRING,
MoonvalleyTextToVideoInferenceParams,
"negative_prompt",
multiline=True,
- default="gopro, bright, contrast, static, overexposed, bright, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, contrast, saturated, vibrant, glowing, cross dissolve, texture, videogame, saturation, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, transition, dissolve, cross-dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring, static",
+ default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts",
),
"resolution": (
IO.COMBO,
{
"options": [
"16:9 (1920 x 1080)",
"9:16 (1080 x 1920)",
"1:1 (1152 x 1152)",
"4:3 (1440 x 1080)",
"3:4 (1080 x 1440)",
"21:9 (2560 x 1080)",
],
"default": "16:9 (1920 x 1080)",
"tooltip": "Resolution of the output video",
},
),
# "length": (IO.COMBO,{"options":['5s','10s'], "default": '5s'}),
"prompt_adherence": model_field_to_node_input(
IO.FLOAT,
MoonvalleyTextToVideoInferenceParams,
"guidance_scale",
default=7.0,
step=1,
min=1,
max=20,
),
"seed": model_field_to_node_input(
IO.INT,
MoonvalleyTextToVideoInferenceParams,
"seed",
default=random.randint(0, 2**32 - 1),
min=0,
max=4294967295,
step=1,
display="number",
tooltip="Random seed value",
control_after_generate=True,
),
"steps": model_field_to_node_input(
IO.INT,
MoonvalleyTextToVideoInferenceParams,
"steps",
default=100,
min=1,
max=100,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
@ -393,7 +484,7 @@ class BaseMoonvalleyVideoNode:
"image_url", "image_url",
tooltip="The reference image used to generate the video", tooltip="The reference image used to generate the video",
), ),
} },
} }
RETURN_TYPES = ("STRING",) RETURN_TYPES = ("STRING",)
@ -404,6 +495,7 @@ class BaseMoonvalleyVideoNode:
def generate(self, **kwargs): def generate(self, **kwargs):
return None return None
# --- MoonvalleyImg2VideoNode --- # --- MoonvalleyImg2VideoNode ---
class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode): class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
@ -415,11 +507,12 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
RETURN_NAMES = ("video",) RETURN_NAMES = ("video",)
DESCRIPTION = "Moonvalley Marey Image to Video Node" DESCRIPTION = "Moonvalley Marey Image to Video Node"
def generate(self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs): def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
image = kwargs.get("image", None) image = kwargs.get("image", None)
if (image is None): if image is None:
raise MoonvalleyApiError("image is required") raise MoonvalleyApiError("image is required")
total_frames = get_total_frames_from_length()
validate_input_image(image, True) validate_input_image(image, True)
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
@ -430,27 +523,28 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
steps=kwargs.get("steps"), steps=kwargs.get("steps"),
seed=kwargs.get("seed"), seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"), guidance_scale=kwargs.get("prompt_adherence"),
num_frames=total_frames, num_frames=128,
width=width_height.get("width"), width=width_height.get("width"),
height=width_height.get("height"), height=width_height.get("height"),
use_negative_prompts=True use_negative_prompts=True,
) )
"""Upload image to comfy backend to have a URL available for further processing""" """Upload image to comfy backend to have a URL available for further processing"""
# Get MIME type from tensor - assuming PNG format for image tensors # Get MIME type from tensor - assuming PNG format for image tensors
mime_type = "image/png" mime_type = "image/png"
image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type)[0] image_url = upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
)[0]
request = MoonvalleyTextToVideoRequest( request = MoonvalleyTextToVideoRequest(
image_url=image_url, image_url=image_url, prompt_text=prompt, inference_params=inference_params
prompt_text=prompt,
inference_params=inference_params
) )
initial_operation = SynchronousOperation( initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, endpoint=ApiEndpoint(
path=API_IMG2VIDEO_ENDPOINT,
method=HttpMethod.POST, method=HttpMethod.POST,
request_model=MoonvalleyTextToVideoRequest, request_model=MoonvalleyTextToVideoRequest,
response_model=MoonvalleyPromptResponse response_model=MoonvalleyPromptResponse,
), ),
request=request, request=request,
auth_kwargs=kwargs, auth_kwargs=kwargs,
@ -465,6 +559,7 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
video = download_url_to_video_output(final_response.output_url) video = download_url_to_video_output(final_response.output_url)
return (video,) return (video,)
# --- MoonvalleyVid2VidNode --- # --- MoonvalleyVid2VidNode ---
class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
def __init__(self): def __init__(self):
@ -472,14 +567,28 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
@classmethod @classmethod
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
input_types = super().INPUT_TYPES() return {
for param in ["resolution", "image"]: "required": {
if param in input_types["required"]: "prompt": model_field_to_node_input(
del input_types["required"][param] IO.STRING, MoonvalleyVideoToVideoRequest, "prompt_text",
if param in input_types["optional"]: multiline=True
del input_types["optional"][param] ),
input_types["optional"] = { "negative_prompt": model_field_to_node_input(
"video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Input a 5s video for 128 frames and a 10s video for 256 frames. Longer videos will be trimmed automatically."}), IO.STRING,
MoonvalleyVideoToVideoInferenceParams,
"negative_prompt",
multiline=True,
default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts"
),
"seed": model_field_to_node_input(IO.INT,MoonvalleyVideoToVideoInferenceParams, "seed", default=random.randint(0, 2**32 - 1), min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", control_after_generate=True),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
"optional": {
"video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported."}),
"control_type": ( "control_type": (
["Motion Transfer", "Pose Transfer"], ["Motion Transfer", "Pose Transfer"],
{"default": "Motion Transfer"}, {"default": "Motion Transfer"},
@ -495,24 +604,22 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
}, },
) )
} }
}
return input_types
RETURN_TYPES = ("VIDEO",) RETURN_TYPES = ("VIDEO",)
RETURN_NAMES = ("video",) RETURN_NAMES = ("video",)
def generate(self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs): def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
video = kwargs.get("video") video = kwargs.get("video")
num_frames = get_total_frames_from_length()
if not video: if not video:
raise MoonvalleyApiError("video is required") raise MoonvalleyApiError("video is required")
"""Validate video input"""
video_url = "" video_url = ""
if video: if video:
validated_video = validate_input_video(video, num_frames, False) validated_video = validate_video_to_video_input(video)
video_url = upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs) video_url = upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs)
control_type = kwargs.get("control_type") control_type = kwargs.get("control_type")
@ -520,12 +627,16 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
"""Validate prompts and inference input""" """Validate prompts and inference input"""
validate_prompts(prompt, negative_prompt) validate_prompts(prompt, negative_prompt)
# Only include motion_intensity for Motion Transfer
control_params = {}
if control_type == "Motion Transfer" and motion_intensity is not None:
control_params['motion_intensity'] = motion_intensity
inference_params=MoonvalleyVideoToVideoInferenceParams( inference_params=MoonvalleyVideoToVideoInferenceParams(
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
steps=kwargs.get("steps"),
seed=kwargs.get("seed"), seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"), control_params=control_params
control_params={'motion_intensity': motion_intensity}
) )
control = self.parseControlParameter(control_type) control = self.parseControlParameter(control_type)
@ -534,14 +645,15 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
control_type=control, control_type=control,
video_url=video_url, video_url=video_url,
prompt_text=prompt, prompt_text=prompt,
inference_params=inference_params inference_params=inference_params,
) )
initial_operation = SynchronousOperation( initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT, endpoint=ApiEndpoint(
path=API_VIDEO2VIDEO_ENDPOINT,
method=HttpMethod.POST, method=HttpMethod.POST,
request_model=MoonvalleyVideoToVideoRequest, request_model=MoonvalleyVideoToVideoRequest,
response_model=MoonvalleyPromptResponse response_model=MoonvalleyPromptResponse,
), ),
request=request, request=request,
auth_kwargs=kwargs, auth_kwargs=kwargs,
@ -558,6 +670,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
return (video,) return (video,)
# --- MoonvalleyTxt2VideoNode --- # --- MoonvalleyTxt2VideoNode ---
class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode): class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
def __init__(self): def __init__(self):
@ -575,30 +688,31 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
del input_types["optional"][param] del input_types["optional"][param]
return input_types return input_types
def generate(self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs): def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution")) width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
num_frames = get_total_frames_from_length()
inference_params=MoonvalleyTextToVideoInferenceParams( inference_params=MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
steps=kwargs.get("steps"), steps=kwargs.get("steps"),
seed=kwargs.get("seed"), seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"), guidance_scale=kwargs.get("prompt_adherence"),
num_frames=num_frames, num_frames=128,
width=width_height.get("width"), width=width_height.get("width"),
height=width_height.get("height"), height=width_height.get("height"),
) )
request = MoonvalleyTextToVideoRequest( request = MoonvalleyTextToVideoRequest(
prompt_text=prompt, prompt_text=prompt, inference_params=inference_params
inference_params=inference_params
) )
initial_operation = SynchronousOperation( initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, endpoint=ApiEndpoint(
path=API_TXT2VIDEO_ENDPOINT,
method=HttpMethod.POST, method=HttpMethod.POST,
request_model=MoonvalleyTextToVideoRequest, request_model=MoonvalleyTextToVideoRequest,
response_model=MoonvalleyPromptResponse response_model=MoonvalleyPromptResponse,
), ),
request=request, request=request,
auth_kwargs=kwargs, auth_kwargs=kwargs,
@ -615,25 +729,15 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
return (video,) return (video,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"MoonvalleyImg2VideoNode": MoonvalleyImg2VideoNode, "MoonvalleyImg2VideoNode": MoonvalleyImg2VideoNode,
"MoonvalleyTxt2VideoNode": MoonvalleyTxt2VideoNode, "MoonvalleyTxt2VideoNode": MoonvalleyTxt2VideoNode,
# "MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode, "MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
"MoonvalleyImg2VideoNode": "Moonvalley Marey Image to Video", "MoonvalleyImg2VideoNode": "Moonvalley Marey Image to Video",
"MoonvalleyTxt2VideoNode": "Moonvalley Marey Text to Video", "MoonvalleyTxt2VideoNode": "Moonvalley Marey Text to Video",
# "MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video", "MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video",
} }
def get_total_frames_from_length(length="5s"):
# if length == '5s':
# return 128
# elif length == '10s':
# return 256
return 128
# else:
# raise MoonvalleyApiError("length is required")
View File
@ -1,3 +1,4 @@
import math
import nodes import nodes
import node_helpers import node_helpers
import torch import torch
@ -5,7 +6,9 @@ import comfy.model_management
import comfy.utils import comfy.utils
import comfy.latent_formats import comfy.latent_formats
import comfy.clip_vision import comfy.clip_vision
import json
import numpy as np
from typing import Tuple
class WanImageToVideo: class WanImageToVideo:
@classmethod @classmethod
@ -383,7 +386,350 @@ class WanPhantomSubjectToVideo:
out_latent["samples"] = latent out_latent["samples"] = latent
return (positive, cond2, negative, out_latent) return (positive, cond2, negative, out_latent)
def parse_json_tracks(tracks):
"""Parse JSON track data into a standardized format"""
tracks_data = []
try:
# If tracks is a string, try to parse it as JSON
if isinstance(tracks, str):
parsed = json.loads(tracks.replace("'", '"'))
tracks_data.extend(parsed)
else:
# If tracks is a list of strings, parse each one
for track_str in tracks:
parsed = json.loads(track_str.replace("'", '"'))
tracks_data.append(parsed)
# Check if we have a single track (dict with x,y) or a list of tracks
if tracks_data and isinstance(tracks_data[0], dict) and 'x' in tracks_data[0]:
# Single track detected, wrap it in a list
tracks_data = [tracks_data]
elif tracks_data and isinstance(tracks_data[0], list) and tracks_data[0] and isinstance(tracks_data[0][0], dict) and 'x' in tracks_data[0][0]:
# Already a list of tracks, nothing to do
pass
else:
# Unexpected format
pass
except json.JSONDecodeError:
tracks_data = []
return tracks_data
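# Illustrative usage sketch (assumes pixel-space "x"/"y" coordinates, as read by
# pad_pts below): the parser accepts a single track or a list of tracks and always
# normalizes to a list of tracks, returning [] on malformed JSON.
#
#   parse_json_tracks('[{"x": 100, "y": 50}, {"x": 110, "y": 55}]')
#   # -> [[{'x': 100, 'y': 50}, {'x': 110, 'y': 55}]]
#   parse_json_tracks('[[{"x": 100, "y": 50}], [{"x": 300, "y": 200}]]')
#   # -> [[{'x': 100, 'y': 50}], [{'x': 300, 'y': 200}]]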
def process_tracks(tracks_np: np.ndarray, frame_size: Tuple[int, int], num_frames, quant_multi: int = 8, **kwargs):
# tracks: shape [t, h, w, 3] => samples align with 24 fps, model trained with 16 fps.
# frame_size: tuple (W, H)
tracks = torch.from_numpy(tracks_np).float()
if tracks.shape[1] == 121:
tracks = torch.permute(tracks, (1, 0, 2, 3))
tracks, visibles = tracks[..., :2], tracks[..., 2:3]
short_edge = min(*frame_size)
frame_center = torch.tensor([*frame_size]).type_as(tracks) / 2
tracks = tracks - frame_center
tracks = tracks / short_edge * 2
visibles = visibles * 2 - 1
trange = torch.linspace(-1, 1, tracks.shape[0]).view(-1, 1, 1, 1).expand(*visibles.shape)
out_ = torch.cat([trange, tracks, visibles], dim=-1).view(121, -1, 4)
out_0 = out_[:1]
out_l = out_[1:] # 121 => 120 | 1
a = 120 // math.gcd(120, num_frames)
b = num_frames // math.gcd(120, num_frames)
out_l = torch.repeat_interleave(out_l, b, dim=0)[1::a] # 120 => 120 * b => 120 * b / a == F
final_result = torch.cat([out_0, out_l], dim=0)
return final_result
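# Worked example of the resampling above (illustrative): the tensor holds 1 anchor
# sample plus 120 motion steps (the "121 => 120 | 1" split). For num_frames = 80,
# i.e. an 81-frame video as passed from WanTrackToVideo below:
#   gcd(120, 80) = 40  ->  a = 3, b = 2
#   repeat_interleave expands 120 steps to 240, then [1::3] keeps exactly 80 of them
#   torch.cat([out_0, out_l]) yields 81 samples, one per output frame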
FIXED_LENGTH = 121
def pad_pts(tr):
"""Convert list of {x,y} to (FIXED_LENGTH,1,3) array, padding/truncating."""
pts = np.array([[p['x'], p['y'], 1] for p in tr], dtype=np.float32)
n = pts.shape[0]
if n < FIXED_LENGTH:
pad = np.zeros((FIXED_LENGTH - n, 3), dtype=np.float32)
pts = np.vstack((pts, pad))
else:
pts = pts[:FIXED_LENGTH]
return pts.reshape(FIXED_LENGTH, 1, 3)
def ind_sel(target: torch.Tensor, ind: torch.Tensor, dim: int = 1):
"""Index selection utility function"""
assert (
len(ind.shape) > dim
), "Index must have the target dim, but get dim: %d, ind shape: %s" % (dim, str(ind.shape))
target = target.expand(
*tuple(
[ind.shape[k] if target.shape[k] == 1 else -1 for k in range(dim)]
+ [
-1,
]
* (len(target.shape) - dim)
)
)
ind_pad = ind
if len(target.shape) > dim + 1:
for _ in range(len(target.shape) - (dim + 1)):
ind_pad = ind_pad.unsqueeze(-1)
ind_pad = ind_pad.expand(*(-1,) * (dim + 1), *target.shape[(dim + 1) : :])
return torch.gather(target, dim=dim, index=ind_pad)
def merge_final(vert_attr: torch.Tensor, weight: torch.Tensor, vert_assign: torch.Tensor):
"""Merge vertex attributes with weights"""
target_dim = len(vert_assign.shape) - 1
if len(vert_attr.shape) == 2:
assert vert_attr.shape[0] > vert_assign.max()
new_shape = [1] * target_dim + list(vert_attr.shape)
tensor = vert_attr.reshape(new_shape)
sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim)
else:
assert vert_attr.shape[1] > vert_assign.max()
new_shape = [vert_attr.shape[0]] + [1] * (target_dim - 1) + list(vert_attr.shape[1:])
tensor = vert_attr.reshape(new_shape)
sel_attr = ind_sel(tensor, vert_assign.type(torch.long), dim=target_dim)
final_attr = torch.sum(sel_attr * weight.unsqueeze(-1), dim=-2)
return final_attr
def _patch_motion_single(
tracks: torch.FloatTensor, # (B, T, N, 4)
vid: torch.FloatTensor, # (C, T, H, W)
temperature: float,
vae_divide: tuple,
topk: int,
):
"""Apply motion patching based on tracks"""
_, T, H, W = vid.shape
N = tracks.shape[2]
_, tracks_xy, visible = torch.split(
tracks, [1, 2, 1], dim=-1
) # (B, T, N, 2) | (B, T, N, 1)
tracks_n = tracks_xy / torch.tensor([W / min(H, W), H / min(H, W)], device=tracks_xy.device)
tracks_n = tracks_n.clamp(-1, 1)
visible = visible.clamp(0, 1)
xx = torch.linspace(-W / min(H, W), W / min(H, W), W)
yy = torch.linspace(-H / min(H, W), H / min(H, W), H)
grid = torch.stack(torch.meshgrid(yy, xx, indexing="ij")[::-1], dim=-1).to(
tracks_xy.device
)
tracks_pad = tracks_xy[:, 1:]
visible_pad = visible[:, 1:]
visible_align = visible_pad.view(T - 1, 4, *visible_pad.shape[2:]).sum(1)
tracks_align = (tracks_pad * visible_pad).view(T - 1, 4, *tracks_pad.shape[2:]).sum(
1
) / (visible_align + 1e-5)
dist_ = (
(tracks_align[:, None, None] - grid[None, :, :, None]).pow(2).sum(-1)
) # T, H, W, N
weight = torch.exp(-dist_ * temperature) * visible_align.clamp(0, 1).view(
T - 1, 1, 1, N
)
vert_weight, vert_index = torch.topk(
weight, k=min(topk, weight.shape[-1]), dim=-1
)
grid_mode = "bilinear"
point_feature = torch.nn.functional.grid_sample(
vid.permute(1, 0, 2, 3)[:1],
tracks_n[:, :1].type(vid.dtype),
mode=grid_mode,
padding_mode="zeros",
align_corners=False,
)
point_feature = point_feature.squeeze(0).squeeze(1).permute(1, 0) # N, C=16
out_feature = merge_final(point_feature, vert_weight, vert_index).permute(3, 0, 1, 2) # T - 1, H, W, C => C, T - 1, H, W
out_weight = vert_weight.sum(-1) # T - 1, H, W
# out feature -> already soft weighted
mix_feature = out_feature + vid[:, 1:] * (1 - out_weight.clamp(0, 1))
out_feature_full = torch.cat([vid[:, :1], mix_feature], dim=1) # C, T, H, W
out_mask_full = torch.cat([torch.ones_like(out_weight[:1]), out_weight], dim=0) # T, H, W
return out_mask_full[None].expand(vae_divide[0], -1, -1, -1), out_feature_full
def patch_motion(
tracks: torch.FloatTensor, # (B, TB, T, N, 4)
vid: torch.FloatTensor, # (B, C, T, H, W)
temperature: float = 220.0,
vae_divide: tuple = (4, 16),
topk: int = 2,
):
B = len(tracks)
# Process each batch separately
out_masks = []
out_features = []
for b in range(B):
mask, feature = _patch_motion_single(
tracks[b], # (T, N, 4)
vid[b], # (C, T, H, W)
temperature,
vae_divide,
topk
)
out_masks.append(mask)
out_features.append(feature)
# Stack results: (B, C, T, H, W)
out_mask_full = torch.stack(out_masks, dim=0)
out_feature_full = torch.stack(out_features, dim=0)
return out_mask_full, out_feature_full
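# Shape note (illustrative): each tracks[b] is one processed track tensor and each
# vid[b] one (C, T, H, W) latent video; the per-item mask is expanded to vae_divide[0]
# channels before stacking, so the outputs are roughly (B, vae_divide[0], T, H, W) and
# (B, C, T, H, W), which is how WanTrackToVideo consumes them below.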
class WanTrackToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"tracks": ("STRING", {"multiline": True, "default": "[]"}),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"temperature": ("FLOAT", {"default": 220.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
"topk": ("INT", {"default": 2, "min": 1, "max": 10}),
"start_image": ("IMAGE", ),
},
"optional": {
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, tracks, width, height, length, batch_size,
temperature, topk, start_image=None, clip_vision_output=None):
tracks_data = parse_json_tracks(tracks)
if not tracks_data:
return WanImageToVideo().encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output)
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
device=comfy.model_management.intermediate_device())
if isinstance(tracks_data[0][0], dict):
tracks_data = [tracks_data]
processed_tracks = []
for batch in tracks_data:
arrs = []
for track in batch:
pts = pad_pts(track)
arrs.append(pts)
tracks_np = np.stack(arrs, axis=0)
processed_tracks.append(process_tracks(tracks_np, (width, height), length - 1).unsqueeze(0))
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:batch_size].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
videos = torch.ones((start_image.shape[0], length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
for i in range(start_image.shape[0]):
videos[i, 0] = start_image[i]
latent_videos = []
videos = comfy.utils.resize_to_batch_size(videos, batch_size)
for i in range(batch_size):
latent_videos += [vae.encode(videos[i, :, :, :, :3])]
y = torch.cat(latent_videos, dim=0)
# Scale latent since patch_motion is non-linear
y = comfy.latent_formats.Wan21().process_in(y)
processed_tracks = comfy.utils.resize_list_to_batch_size(processed_tracks, batch_size)
res = patch_motion(
processed_tracks, y, temperature=temperature, topk=topk, vae_divide=(4, 16)
)
mask, concat_latent_image = res
concat_latent_image = comfy.latent_formats.Wan21().process_out(concat_latent_image)
mask = -mask + 1.0 # Invert mask to match expected format
positive = node_helpers.conditioning_set_values(positive,
{"concat_mask": mask,
"concat_latent_image": concat_latent_image})
negative = node_helpers.conditioning_set_values(negative,
{"concat_mask": mask,
"concat_latent_image": concat_latent_image})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
class Wan22ImageToVideoLatent:
@classmethod
def INPUT_TYPES(s):
return {"required": {"vae": ("VAE", ),
"width": ("INT", {"default": 1280, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}),
"height": ("INT", {"default": 704, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}),
"length": ("INT", {"default": 49, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"start_image": ("IMAGE", ),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "conditioning/inpaint"
def encode(self, vae, width, height, length, batch_size, start_image=None):
latent = torch.zeros([1, 48, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
if start_image is None:
out_latent = {}
out_latent["samples"] = latent
return (out_latent,)
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
latent_temp = vae.encode(start_image)
latent[:, :, :latent_temp.shape[-3]] = latent_temp
mask[:, :, :latent_temp.shape[-3]] *= 0.0
out_latent = {}
latent_format = comfy.latent_formats.Wan22()
latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
return (out_latent,)
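# Dimension check for the defaults above (illustrative): width=1280, height=704,
# length=49 gives an empty latent of shape [1, 48, 13, 44, 80] -- 48 channels,
# (49 - 1) // 4 + 1 = 13 latent frames, and a 16x spatial downscale. When start_image
# is provided, only the leading latent frames are overwritten and the noise_mask is
# zeroed over them so the sampler treats them as fixed content.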
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"WanTrackToVideo": WanTrackToVideo,
"WanImageToVideo": WanImageToVideo, "WanImageToVideo": WanImageToVideo,
"WanFunControlToVideo": WanFunControlToVideo, "WanFunControlToVideo": WanFunControlToVideo,
"WanFunInpaintToVideo": WanFunInpaintToVideo, "WanFunInpaintToVideo": WanFunInpaintToVideo,
@ -392,4 +738,5 @@ NODE_CLASS_MAPPINGS = {
"TrimVideoLatent": TrimVideoLatent, "TrimVideoLatent": TrimVideoLatent,
"WanCameraImageToVideo": WanCameraImageToVideo, "WanCameraImageToVideo": WanCameraImageToVideo,
"WanPhantomSubjectToVideo": WanPhantomSubjectToVideo, "WanPhantomSubjectToVideo": WanPhantomSubjectToVideo,
"Wan22ImageToVideoLatent": Wan22ImageToVideoLatent,
} }
View File
@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.3.45" __version__ = "0.3.46"
View File
@ -1176,7 +1176,7 @@ class PromptQueue:
return True return True
return False return False
def get_history(self, prompt_id=None, max_items=None, offset=-1): def get_history(self, prompt_id=None, max_items=None, offset=-1, map_function=None):
with self.mutex: with self.mutex:
if prompt_id is None: if prompt_id is None:
out = {} out = {}
@ -1185,13 +1185,21 @@ class PromptQueue:
offset = len(self.history) - max_items offset = len(self.history) - max_items
for k in self.history: for k in self.history:
if i >= offset: if i >= offset:
out[k] = self.history[k] p = self.history[k]
if map_function is not None:
p = map_function(p)
out[k] = p
if max_items is not None and len(out) >= max_items: if max_items is not None and len(out) >= max_items:
break break
i += 1 i += 1
return out return out
elif prompt_id in self.history: elif prompt_id in self.history:
return {prompt_id: copy.deepcopy(self.history[prompt_id])} p = self.history[prompt_id]
if map_function is None:
p = copy.deepcopy(p)
else:
p = map_function(p)
return {prompt_id: p}
else: else:
return {} return {}
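# Illustrative use of the new map_function hook (the "outputs" key and the
# prompt_queue name are assumptions, not taken from this file): callers can slim
# down entries before they are returned, e.g.
#
#   slim = prompt_queue.get_history(
#       max_items=50,
#       map_function=lambda entry: {k: v for k, v in entry.items() if k != "outputs"},
#   )
#
# Note that in the single-prompt branch the entry is only deep-copied when no
# map_function is given, so a map_function should not mutate its argument.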
View File
@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.3.45" version = "0.3.46"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.9" requires-python = ">=3.9"
View File
@ -1,5 +1,5 @@
comfyui-frontend-package==1.23.4 comfyui-frontend-package==1.23.4
comfyui-workflow-templates==0.1.39 comfyui-workflow-templates==0.1.41
comfyui-embedded-docs==0.2.4 comfyui-embedded-docs==0.2.4
torch torch
torchsde torchsde