Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-08-02 19:26:43 +00:00)

Commit 5d5024296d: Merge branch 'master' into worksplit-multigpu
@@ -55,7 +55,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
 ## Features
 - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
 - Image Models
-    - SD1.x, SD2.x,
+    - SD1.x, SD2.x ([unCLIP](https://comfyanonymous.github.io/ComfyUI_examples/unclip/))
     - [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
     - [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
     - [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/)
@@ -77,6 +77,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
     - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
     - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
     - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
+    - [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
 - Audio Models
     - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
     - [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
@@ -84,9 +85,9 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
     - [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
 - Asynchronous Queue system
 - Many optimizations: Only re-executes the parts of the workflow that changes between executions.
-- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
+- Smart memory management: can automatically run large models on GPUs with as low as 1GB vram with smart offloading.
 - Works even if you don't have a GPU with: ```--cpu``` (slow)
-- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
+- Can load ckpt and safetensors: All in one checkpoints or standalone diffusion models, VAEs and CLIP models.
 - Safe loading of ckpt, pt, pth, etc.. files.
 - Embeddings/Textual inversion
 - [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
@@ -98,7 +99,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
 - [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
 - [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
 - [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
-- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
 - [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
 - [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
 - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
@@ -457,6 +457,82 @@ class Wan21(LatentFormat):
         latents_std = self.latents_std.to(latent.device, latent.dtype)
         return latent * latents_std / self.scale_factor + latents_mean

+
+class Wan22(Wan21):
+    latent_channels = 48
+    latent_dimensions = 3
+
+    latent_rgb_factors = [
+        [ 0.0119,  0.0103,  0.0046],
+        [-0.1062, -0.0504,  0.0165],
+        [ 0.0140,  0.0409,  0.0491],
+        [-0.0813, -0.0677,  0.0607],
+        [ 0.0656,  0.0851,  0.0808],
+        [ 0.0264,  0.0463,  0.0912],
+        [ 0.0295,  0.0326,  0.0590],
+        [-0.0244, -0.0270,  0.0025],
+        [ 0.0443, -0.0102,  0.0288],
+        [-0.0465, -0.0090, -0.0205],
+        [ 0.0359,  0.0236,  0.0082],
+        [-0.0776,  0.0854,  0.1048],
+        [ 0.0564,  0.0264,  0.0561],
+        [ 0.0006,  0.0594,  0.0418],
+        [-0.0319, -0.0542, -0.0637],
+        [-0.0268,  0.0024,  0.0260],
+        [ 0.0539,  0.0265,  0.0358],
+        [-0.0359, -0.0312, -0.0287],
+        [-0.0285, -0.1032, -0.1237],
+        [ 0.1041,  0.0537,  0.0622],
+        [-0.0086, -0.0374, -0.0051],
+        [ 0.0390,  0.0670,  0.2863],
+        [ 0.0069,  0.0144,  0.0082],
+        [ 0.0006, -0.0167,  0.0079],
+        [ 0.0313, -0.0574, -0.0232],
+        [-0.1454, -0.0902, -0.0481],
+        [ 0.0714,  0.0827,  0.0447],
+        [-0.0304, -0.0574, -0.0196],
+        [ 0.0401,  0.0384,  0.0204],
+        [-0.0758, -0.0297, -0.0014],
+        [ 0.0568,  0.1307,  0.1372],
+        [-0.0055, -0.0310, -0.0380],
+        [ 0.0239, -0.0305,  0.0325],
+        [-0.0663, -0.0673, -0.0140],
+        [-0.0416, -0.0047, -0.0023],
+        [ 0.0166,  0.0112, -0.0093],
+        [-0.0211,  0.0011,  0.0331],
+        [ 0.1833,  0.1466,  0.2250],
+        [-0.0368,  0.0370,  0.0295],
+        [-0.3441, -0.3543, -0.2008],
+        [-0.0479, -0.0489, -0.0420],
+        [-0.0660, -0.0153,  0.0800],
+        [-0.0101,  0.0068,  0.0156],
+        [-0.0690, -0.0452, -0.0927],
+        [-0.0145,  0.0041,  0.0015],
+        [ 0.0421,  0.0451,  0.0373],
+        [ 0.0504, -0.0483, -0.0356],
+        [-0.0837,  0.0168,  0.0055]
+    ]
+
+    latent_rgb_factors_bias = [0.0317, -0.0878, -0.1388]
+
+    def __init__(self):
+        self.scale_factor = 1.0
+        self.latents_mean = torch.tensor([
+            -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
+            -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
+            -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
+            -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
+            -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
+            0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
+        ]).view(1, self.latent_channels, 1, 1, 1)
+        self.latents_std = torch.tensor([
+            0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
+            0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
+            0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
+            0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
+            0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
+            0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
+        ]).view(1, self.latent_channels, 1, 1, 1)
+
 class Hunyuan3Dv2(LatentFormat):
     latent_channels = 64
     latent_dimensions = 1
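The `latent_rgb_factors` table added above gives each of the 48 Wan 2.2 latent channels an approximate RGB contribution, plus a bias, which front-ends use for cheap latent previews. A rough, illustrative sketch of how such a projection can be applied; the function and tensor names are invented for this example and are not part of the commit:

```python
import torch

def latent_to_rgb_preview(latent, factors, bias):
    # latent: [B, 48, T, H, W]; factors: the [48, 3] table above; bias: the 3-element bias
    rgb = torch.einsum("bcthw,cr->brthw", latent, factors) + bias.view(1, 3, 1, 1, 1)
    return rgb.clamp(-1, 1)  # roughly display range after rescaling
```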
@@ -201,8 +201,10 @@ class WanAttentionBlock(nn.Module):
             freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         """
         # assert e.dtype == torch.float32
-        e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
+        if e.ndim < 4:
+            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
+        else:
+            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
         # assert e[0].dtype == torch.float32

         # self-attention
@@ -325,7 +327,10 @@ class Head(nn.Module):
             e(Tensor): Shape [B, C]
         """
         # assert e.dtype == torch.float32
-        e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
+        if e.ndim < 3:
+            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
+        else:
+            e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
         x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
         return x

@@ -506,8 +511,9 @@ class WanModel(torch.nn.Module):

         # time embeddings
         e = self.time_embedding(
-            sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
-        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+            sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
+        e = e.reshape(t.shape[0], -1, e.shape[-1])
+        e0 = self.time_projection(e).unflatten(2, (6, self.dim))

         # context
         context = self.text_embedding(context)
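The new `e.ndim` branches let the same modulation table handle either one timestep embedding per sample or one per latent frame (the shape produced by the new `t.flatten()` / `reshape` path in `WanModel`). A small shape sketch with toy dimensions, purely illustrative and not taken from the commit:

```python
import torch

B, T, dim = 2, 5, 16
modulation = torch.zeros(1, 6, dim)                              # stand-in for self.modulation

e_shared = torch.zeros(B, 6, dim)                                # one timestep per sample
per_block = (modulation + e_shared).chunk(6, dim=1)              # six [B, 1, dim] tensors

e_per_frame = torch.zeros(B, T, 6, dim)                          # one timestep per frame
per_block = (modulation.unsqueeze(0) + e_per_frame).unbind(2)    # six [B, T, dim] tensors
```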
726 comfy/ldm/wan/vae2_2.py (new file)
@@ -0,0 +1,726 @@
# original version: https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/vae2_2.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .vae import AttentionBlock, CausalConv3d, RMS_norm

import comfy.ops
ops = comfy.ops.disable_weight_init

CACHE_T = 2


class Resample(nn.Module):

    def __init__(self, dim, mode):
        assert mode in (
            "none",
            "upsample2d",
            "upsample3d",
            "downsample2d",
            "downsample3d",
        )
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == "upsample2d":
            self.resample = nn.Sequential(
                nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                ops.Conv2d(dim, dim, 3, padding=1),
            )
        elif mode == "upsample3d":
            self.resample = nn.Sequential(
                nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                ops.Conv2d(dim, dim, 3, padding=1),
                # ops.Conv2d(dim, dim//2, 3, padding=1)
            )
            self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
        elif mode == "downsample2d":
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                ops.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == "downsample3d":
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                ops.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        b, c, t, h, w = x.size()
        if self.mode == "upsample3d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = "Rep"
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
                    if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
                            feat_cache[idx] != "Rep"):
                        # cache last frame of last two chunk
                        cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                    if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
                            feat_cache[idx] == "Rep"):
                        cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
                    if feat_cache[idx] == "Rep":
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
                    x = x.reshape(b, c, t * 2, h, w)
        t = x.shape[2]
        x = rearrange(x, "b c t h w -> (b t) c h w")
        x = self.resample(x)
        x = rearrange(x, "(b t) c h w -> b c t h w", t=t)

        if self.mode == "downsample3d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -1:, :, :].clone()
                    x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
        return x

class ResidualBlock(nn.Module):

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False),
            nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False),
            nn.SiLU(),
            nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1),
        )
        self.shortcut = (
            CausalConv3d(in_dim, out_dim, 1)
            if in_dim != out_dim else nn.Identity())

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        h = self.shortcut(x)
        for layer in self.residual:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + h

def patchify(x, patch_size):
    if patch_size == 1:
        return x
    if x.dim() == 4:
        x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
    elif x.dim() == 5:
        x = rearrange(x, "b c f (h q) (w r) -> b (c r q) f h w", q=patch_size, r=patch_size)
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    return x


def unpatchify(x, patch_size):
    if patch_size == 1:
        return x

    if x.dim() == 4:
        x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
    elif x.dim() == 5:
        x = rearrange(x, "b (c r q) f h w -> b c f (h q) (w r)", q=patch_size, r=patch_size)
    return x

class AvgDown3D(nn.Module):

    def __init__(self, in_channels, out_channels, factor_t, factor_s=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = self.factor_t * self.factor_s * self.factor_s

        assert in_channels * self.factor % out_channels == 0
        self.group_size = in_channels * self.factor // out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
        pad = (0, 0, 0, 0, pad_t, 0)
        x = F.pad(x, pad)
        B, C, T, H, W = x.shape
        x = x.view(B, C, T // self.factor_t, self.factor_t, H // self.factor_s, self.factor_s, W // self.factor_s, self.factor_s)
        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
        x = x.view(B, C * self.factor, T // self.factor_t, H // self.factor_s, W // self.factor_s)
        x = x.view(B, self.out_channels, self.group_size, T // self.factor_t, H // self.factor_s, W // self.factor_s)
        x = x.mean(dim=2)
        return x


class DupUp3D(nn.Module):

    def __init__(self, in_channels: int, out_channels: int, factor_t, factor_s=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = self.factor_t * self.factor_s * self.factor_s

        assert out_channels * self.factor % in_channels == 0
        self.repeats = out_channels * self.factor // in_channels

    def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
        x = x.repeat_interleave(self.repeats, dim=1)
        x = x.view(x.size(0), self.out_channels, self.factor_t, self.factor_s, self.factor_s, x.size(2), x.size(3), x.size(4))
        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        x = x.view(x.size(0), self.out_channels, x.size(2) * self.factor_t, x.size(4) * self.factor_s, x.size(6) * self.factor_s)
        if first_chunk:
            x = x[:, :, self.factor_t - 1:, :, :]
        return x

class Down_ResidualBlock(nn.Module):

    def __init__(self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False):
        super().__init__()

        # Shortcut path with downsample
        self.avg_shortcut = AvgDown3D(
            in_dim,
            out_dim,
            factor_t=2 if temperal_downsample else 1,
            factor_s=2 if down_flag else 1,
        )

        # Main path with residual blocks and downsample
        downsamples = []
        for _ in range(mult):
            downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
            in_dim = out_dim

        # Add the final downsample block
        if down_flag:
            mode = "downsample3d" if temperal_downsample else "downsample2d"
            downsamples.append(Resample(out_dim, mode=mode))

        self.downsamples = nn.Sequential(*downsamples)

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        x_copy = x.clone()
        for module in self.downsamples:
            x = module(x, feat_cache, feat_idx)

        return x + self.avg_shortcut(x_copy)


class Up_ResidualBlock(nn.Module):

    def __init__(self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False):
        super().__init__()
        # Shortcut path with upsample
        if up_flag:
            self.avg_shortcut = DupUp3D(
                in_dim,
                out_dim,
                factor_t=2 if temperal_upsample else 1,
                factor_s=2 if up_flag else 1,
            )
        else:
            self.avg_shortcut = None

        # Main path with residual blocks and upsample
        upsamples = []
        for _ in range(mult):
            upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
            in_dim = out_dim

        # Add the final upsample block
        if up_flag:
            mode = "upsample3d" if temperal_upsample else "upsample2d"
            upsamples.append(Resample(out_dim, mode=mode))

        self.upsamples = nn.Sequential(*upsamples)

    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
        x_main = x.clone()
        for module in self.upsamples:
            x_main = module(x_main, feat_cache, feat_idx)
        if self.avg_shortcut is not None:
            x_shortcut = self.avg_shortcut(x, first_chunk)
            return x_main + x_shortcut
        else:
            return x_main

class Encoder3d(nn.Module):

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, False],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block
        self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            t_down_flag = (
                temperal_downsample[i]
                if i < len(temperal_downsample) else False)
            downsamples.append(
                Down_ResidualBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    dropout=dropout,
                    mult=num_res_blocks,
                    temperal_downsample=t_down_flag,
                    down_flag=i != len(dim_mult) - 1,
                ))
            scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(out_dim, out_dim, dropout),
            AttentionBlock(out_dim),
            ResidualBlock(out_dim, out_dim, dropout),
        )

        # # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False),
            nn.SiLU(),
            CausalConv3d(out_dim, z_dim, 3, padding=1),
        )

    def forward(self, x, feat_cache=None, feat_idx=[0]):

        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## downsamples
        for layer in self.downsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)

        return x

class Decoder3d(nn.Module):

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_upsample=[False, True, True],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout),
            AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0], dropout),
        )

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False
            upsamples.append(
                Up_ResidualBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    dropout=dropout,
                    mult=num_res_blocks + 1,
                    temperal_upsample=t_up_flag,
                    up_flag=i != len(dim_mult) - 1,
                ))
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False),
            nn.SiLU(),
            CausalConv3d(out_dim, 12, 3, padding=1),
        )

    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx, first_chunk)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x

def count_conv3d(model):
    count = 0
    for m in model.modules():
        if isinstance(m, CausalConv3d):
            count += 1
    return count

class WanVAE(nn.Module):

    def __init__(
        self,
        dim=160,
        dec_dim=256,
        z_dim=16,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, False],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder3d(
            dim,
            z_dim * 2,
            dim_mult,
            num_res_blocks,
            attn_scales,
            self.temperal_downsample,
            dropout,
        )
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(
            dec_dim,
            z_dim,
            dim_mult,
            num_res_blocks,
            attn_scales,
            self.temperal_upsample,
            dropout,
        )

    def encode(self, x):
        self.clear_cache()
        x = patchify(x, patch_size=2)
        t = x.shape[2]
        iter_ = 1 + (t - 1) // 4
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(
                    x[:, :, :1, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx,
                )
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx,
                )
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        self.clear_cache()
        return mu

    def decode(self, z):
        self.clear_cache()
        iter_ = z.shape[2]
        x = self.conv2(z)
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx,
                    first_chunk=True,
                )
            else:
                out_ = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx,
                )
                out = torch.cat([out, out_], 2)
        out = unpatchify(out, patch_size=2)
        self.clear_cache()
        return out

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False):
        mu, log_var = self.encode(imgs)
        if deterministic:
            return mu
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # cache encode
        self._enc_conv_num = count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num
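A hedged usage sketch for the new VAE, using the same configuration that comfy/sd.py passes in further down. With `temperal_downsample=[False, True, True]` and the 2x patchify step, the encoder reduces time by 4x and space by 16x into 48 latent channels. The snippet only illustrates shapes with untrained weights, assumes a ComfyUI checkout on the import path, and is not part of the commit:

```python
import torch
from comfy.ldm.wan.vae2_2 import WanVAE

vae = WanVAE(dim=160, dec_dim=256, z_dim=48, dim_mult=[1, 2, 4, 4], num_res_blocks=2,
             attn_scales=[], temperal_downsample=[False, True, True], dropout=0.0)

video = torch.randn(1, 3, 17, 256, 256)   # [B, C, 1 + 4k frames, H, W]
latent = vae.encode(video)                # -> [1, 48, 5, 16, 16]
recon = vae.decode(latent)                # -> [1, 3, 17, 256, 256]
```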
@@ -1097,6 +1097,7 @@ class WAN21(BaseModel):
             image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
         image = utils.resize_to_batch_size(image, noise.shape[0])

+        if extra_channels != image.shape[1] + 4:
         if not self.image_to_video or extra_channels == image.shape[1]:
             return image

@@ -1182,6 +1183,31 @@ class WAN21_Camera(WAN21):
         out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
         return out

+
+class WAN22(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
+        self.image_to_video = image_to_video
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+        denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+        if denoise_mask is not None:
+            out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
+        return out
+
+    def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
+        if denoise_mask is None:
+            return timestep
+        temp_ts = (torch.mean(denoise_mask[:, :, :, ::2, ::2], dim=1, keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
+        return temp_ts
+
+    def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
+        return latent_image
+
 class Hunyuan3Dv2(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
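`WAN22.process_timestep` above converts a scalar timestep into one value per latent token by averaging the denoise mask over channels at a 2x2 spatial stride, so tokens the mask marks as already clean are effectively denoised with timestep 0. A toy illustration of that arithmetic, with shapes invented for the example:

```python
import torch

timestep = torch.tensor([999.0])                 # [B]
denoise_mask = torch.zeros(1, 1, 4, 8, 8)        # [B, C, T, H, W]; 1 = still noisy
denoise_mask[:, :, 2:] = 1.0                     # only the last two latent frames are noisy

weights = torch.mean(denoise_mask[:, :, :, ::2, ::2], dim=1, keepdim=True)   # [1, 1, 4, 4, 4]
temp_ts = (weights * timestep.view(1, 1, 1, 1, 1)).reshape(1, -1)            # [1, 64]
# tokens in the noisy frames keep 999.0, tokens in the clean frames get 0.0
```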
@@ -346,7 +346,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config = {}
         dit_config["image_model"] = "wan2.1"
         dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
+        out_dim = state_dict['{}head.head.weight'.format(key_prefix)].shape[0] // 4
         dit_config["dim"] = dim
+        dit_config["out_dim"] = out_dim
         dit_config["num_heads"] = dim // 128
         dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
         dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
comfy/sd.py (14 lines changed)
@@ -14,6 +14,7 @@ import comfy.ldm.genmo.vae.model
 import comfy.ldm.lightricks.vae.causal_video_autoencoder
 import comfy.ldm.cosmos.vae
 import comfy.ldm.wan.vae
+import comfy.ldm.wan.vae2_2
 import comfy.ldm.hunyuan3d.vae
 import comfy.ldm.ace.vae.music_dcae_pipeline
 import yaml
@@ -420,6 +421,19 @@ class VAE:
             self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
             self.working_dtypes = [torch.bfloat16, torch.float32]
         elif "decoder.middle.0.residual.0.gamma" in sd:
+            if "decoder.upsamples.0.upsamples.0.residual.2.weight" in sd: # Wan 2.2 VAE
+                self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
+                self.upscale_index_formula = (4, 16, 16)
+                self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
+                self.downscale_index_formula = (4, 16, 16)
+                self.latent_dim = 3
+                self.latent_channels = 48
+                ddconfig = {"dim": 160, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
+                self.first_stage_model = comfy.ldm.wan.vae2_2.WanVAE(**ddconfig)
+                self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+                self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
+                self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
+            else: # Wan 2.1 VAE
             self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
             self.upscale_index_formula = (4, 8, 8)
             self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
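The loader above tells the two Wan VAEs apart purely by which keys are present in the checkpoint's state dict. A minimal restatement of that check; the helper name is invented for illustration, the key strings are the ones from the hunk:

```python
def is_wan22_vae(sd: dict) -> bool:
    # Shared Wan VAE marker plus an upsample weight that only the Wan 2.2 decoder has.
    return ("decoder.middle.0.residual.0.gamma" in sd
            and "decoder.upsamples.0.upsamples.0.residual.2.weight" in sd)
```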
@@ -1059,6 +1059,19 @@ class WAN21_Vace(WAN21_T2V):
         out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
         return out

+
+class WAN22_T2V(WAN21_T2V):
+    unet_config = {
+        "image_model": "wan2.1",
+        "model_type": "t2v",
+        "out_dim": 48,
+    }
+
+    latent_format = latent_formats.Wan22
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.WAN22(self, image_to_video=True, device=device)
+        return out
+
 class Hunyuan3Dv2(supported_models_base.BASE):
     unet_config = {
         "image_model": "hunyuan3d2",
@@ -1217,6 +1230,6 @@ class Omnigen2(supported_models_base.BASE):
         return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))


-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]

 models += [SVD_img2vid]
@@ -2,6 +2,7 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import Optional, Union
 import io
+import av
 from comfy_api.util import VideoContainer, VideoCodec, VideoComponents

 class VideoInput(ABC):
@@ -70,3 +71,15 @@ class VideoInput(ABC):
         components = self.get_components()
         frame_count = components.images.shape[0]
         return float(frame_count / components.frame_rate)
+
+    def get_container_format(self) -> str:
+        """
+        Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
+
+        Returns:
+            Container format as string
+        """
+        # Default implementation - subclasses should override for better performance
+        source = self.get_stream_source()
+        with av.open(source, mode="r") as container:
+            return container.format.name
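A short usage sketch for the new default `get_container_format()`. PyAV reports the demuxer name, which for MP4 input is the combined string that the Moonvalley node later accepts alongside plain 'mp4'. The file name below is a placeholder, not something from the commit:

```python
import av

with av.open("example.mp4", mode="r") as container:
    print(container.format.name)   # e.g. 'mov,mp4,m4a,3gp,3g2,mj2' for MP4 files
```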
@@ -121,6 +121,18 @@ class VideoFromFile(VideoInput):

         raise ValueError(f"Could not determine duration for file '{self.__file}'")

+    def get_container_format(self) -> str:
+        """
+        Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
+
+        Returns:
+            Container format as string
+        """
+        if isinstance(self.__file, io.BytesIO):
+            self.__file.seek(0)
+        with av.open(self.__file, mode='r') as container:
+            return container.format.name
+
     def get_components_internal(self, container: InputContainer) -> VideoComponents:
         # Get video frames
         frames = []
@@ -5,7 +5,6 @@ import torch
 from comfy_api_nodes.util.validation_utils import (
     get_image_dimensions,
     validate_image_dimensions,
-    validate_video_dimensions,
 )


@@ -176,54 +175,76 @@ def validate_input_image(
     )


-def validate_input_video(
-    video: VideoInput, num_frames_out: int, with_frame_conditioning: bool = False
-):
+def validate_video_to_video_input(video: VideoInput) -> VideoInput:
+    """
+    Validates and processes video input for Moonvalley Video-to-Video generation.
+
+    Args:
+        video: Input video to validate
+
+    Returns:
+        Validated and potentially trimmed video
+
+    Raises:
+        ValueError: If video doesn't meet requirements
+        MoonvalleyApiError: If video duration is too short
+    """
+    width, height = _get_video_dimensions(video)
+    _validate_video_dimensions(width, height)
+    _validate_container_format(video)
+
+    return _validate_and_trim_duration(video)
+
+
+def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
+    """Extracts video dimensions with error handling."""
     try:
-        width, height = video.get_dimensions()
+        return video.get_dimensions()
     except Exception as e:
         logging.error("Error getting dimensions of video: %s", e)
         raise ValueError(f"Cannot get video dimensions: {e}") from e

-    validate_input_media(width, height, with_frame_conditioning)
-    validate_video_dimensions(
-        video,
-        min_width=MIN_VID_WIDTH,
-        min_height=MIN_VID_HEIGHT,
-        max_width=MAX_VID_WIDTH,
-        max_height=MAX_VID_HEIGHT,
-    )
-
-    trimmed_video = validate_input_video_length(video, num_frames_out)
-    return trimmed_video
+
+def _validate_video_dimensions(width: int, height: int) -> None:
+    """Validates video dimensions meet Moonvalley V2V requirements."""
+    supported_resolutions = {
+        (1920, 1080), (1080, 1920), (1152, 1152),
+        (1536, 1152), (1152, 1536)
+    }
+
+    if (width, height) not in supported_resolutions:
+        supported_list = ', '.join([f'{w}x{h}' for w, h in sorted(supported_resolutions)])
+        raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")


-def validate_input_video_length(video: VideoInput, num_frames: int):
-
-    if video.get_duration() > 60:
-        raise MoonvalleyApiError(
-            "Input Video lenth should be less than 1min. Please trim."
-        )
-
-    if num_frames == 128:
-        if video.get_duration() < 5:
-            raise MoonvalleyApiError(
-                "Input Video length is less than 5s. Please use a video longer than or equal to 5s."
-            )
-        if video.get_duration() > 5:
-            # trim video to 5s
-            video = trim_video(video, 5)
-    if num_frames == 256:
-        if video.get_duration() < 10:
-            raise MoonvalleyApiError(
-                "Input Video length is less than 10s. Please use a video longer than or equal to 10s."
-            )
-        if video.get_duration() > 10:
-            # trim video to 10s
-            video = trim_video(video, 10)
+def _validate_container_format(video: VideoInput) -> None:
+    """Validates video container format is MP4."""
+    container_format = video.get_container_format()
+    if container_format not in ['mp4', 'mov,mp4,m4a,3gp,3g2,mj2']:
+        raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
+
+
+def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
+    """Validates video duration and trims to 5 seconds if needed."""
+    duration = video.get_duration()
+    _validate_minimum_duration(duration)
+    return _trim_if_too_long(video, duration)
+
+
+def _validate_minimum_duration(duration: float) -> None:
+    """Ensures video is at least 5 seconds long."""
+    if duration < 5:
+        raise MoonvalleyApiError("Input video must be at least 5 seconds long.")
+
+
+def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
+    """Trims video to 5 seconds if longer."""
+    if duration > 5:
+        return trim_video(video, 5)
     return video



 def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
     """
     Returns a new VideoInput object trimmed from the beginning to the specified duration,
@@ -278,15 +299,13 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
                 f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels"
             )

-    # Calculate target frame count that's divisible by 32
+    # Calculate target frame count that's divisible by 16
     fps = input_container.streams.video[0].average_rate
     estimated_frames = int(duration_sec * fps)
-    target_frames = (
-        estimated_frames // 32
-    ) * 32  # Round down to nearest multiple of 32
+    target_frames = (estimated_frames // 16) * 16  # Round down to nearest multiple of 16

     if target_frames == 0:
-        raise ValueError("Video too short: need at least 32 frames for Moonvalley")
+        raise ValueError("Video too short: need at least 16 frames for Moonvalley")

     frame_count = 0
     audio_frame_count = 0
@@ -353,8 +372,8 @@ class BaseMoonvalleyVideoNode:
             "16:9 (1920 x 1080)": {"width": 1920, "height": 1080},
             "9:16 (1080 x 1920)": {"width": 1080, "height": 1920},
             "1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
-            "4:3 (1440 x 1080)": {"width": 1440, "height": 1080},
-            "3:4 (1080 x 1440)": {"width": 1080, "height": 1440},
+            "4:3 (1536 x 1152)": {"width": 1536, "height": 1152},
+            "3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
             "21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
         }
         if resolution in res_map:
@@ -494,7 +513,6 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
         image = kwargs.get("image", None)
         if image is None:
             raise MoonvalleyApiError("image is required")
-        total_frames = get_total_frames_from_length()

         validate_input_image(image, True)
         validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
@@ -505,7 +523,7 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
             steps=kwargs.get("steps"),
             seed=kwargs.get("seed"),
             guidance_scale=kwargs.get("prompt_adherence"),
-            num_frames=total_frames,
+            num_frames=128,
             width=width_height.get("width"),
             height=width_height.get("height"),
             use_negative_prompts=True,
@@ -549,21 +567,28 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):

     @classmethod
     def INPUT_TYPES(cls):
-        input_types = super().INPUT_TYPES()
-        for param in ["resolution", "image"]:
-            if param in input_types["required"]:
-                del input_types["required"][param]
-            if param in input_types["optional"]:
-                del input_types["optional"][param]
-        input_types["optional"] = {
-            "video": (
-                IO.VIDEO,
-                {
-                    "default": "",
-                    "multiline": False,
-                    "tooltip": "The reference video used to generate the output video. Input a 5s video for 128 frames and a 10s video for 256 frames. Longer videos will be trimmed automatically.",
-                },
-            ),
+        return {
+            "required": {
+                "prompt": model_field_to_node_input(
+                    IO.STRING, MoonvalleyVideoToVideoRequest, "prompt_text",
+                    multiline=True
+                ),
+                "negative_prompt": model_field_to_node_input(
+                    IO.STRING,
+                    MoonvalleyVideoToVideoInferenceParams,
+                    "negative_prompt",
+                    multiline=True,
+                    default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts"
+                ),
+                "seed": model_field_to_node_input(IO.INT,MoonvalleyVideoToVideoInferenceParams, "seed", default=random.randint(0, 2**32 - 1), min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", control_after_generate=True),
+            },
+            "hidden": {
+                "auth_token": "AUTH_TOKEN_COMFY_ORG",
+                "comfy_api_key": "API_KEY_COMFY_ORG",
+                "unique_id": "UNIQUE_ID",
+            },
+            "optional": {
+                "video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported."}),
                 "control_type": (
                     ["Motion Transfer", "Pose Transfer"],
                     {"default": "Motion Transfer"},
@@ -577,10 +602,9 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
                         "max": 100,
                         "tooltip": "Only used if control_type is 'Motion Transfer'",
                     },
-                ),
+                )
+            }
         }

-        return input_types

     RETURN_TYPES = ("VIDEO",)
     RETURN_NAMES = ("video",)
@@ -589,15 +613,13 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
         self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
     ):
         video = kwargs.get("video")
-        num_frames = get_total_frames_from_length()
 
         if not video:
             raise MoonvalleyApiError("video is required")
 
-        """Validate video input"""
         video_url = ""
         if video:
-            validated_video = validate_input_video(video, num_frames, False)
+            validated_video = validate_video_to_video_input(video)
             video_url = upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs)
 
         control_type = kwargs.get("control_type")
@@ -605,12 +627,16 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
 
         """Validate prompts and inference input"""
         validate_prompts(prompt, negative_prompt)
-        inference_params = MoonvalleyVideoToVideoInferenceParams(
+
+        # Only include motion_intensity for Motion Transfer
+        control_params = {}
+        if control_type == "Motion Transfer" and motion_intensity is not None:
+            control_params['motion_intensity'] = motion_intensity
+
+        inference_params=MoonvalleyVideoToVideoInferenceParams(
             negative_prompt=negative_prompt,
-            steps=kwargs.get("steps"),
             seed=kwargs.get("seed"),
-            guidance_scale=kwargs.get("prompt_adherence"),
-            control_params={"motion_intensity": motion_intensity},
+            control_params=control_params
         )
 
         control = self.parseControlParameter(control_type)
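A quick sketch of what the control-parameter gating in this hunk produces (example values assumed here, not taken from the commit): motion_intensity is only forwarded to the inference params when the control type is Motion Transfer.

# Hypothetical illustration of the gating above; the input values are made up.
def build_control_params(control_type, motion_intensity):
    control_params = {}
    if control_type == "Motion Transfer" and motion_intensity is not None:
        control_params['motion_intensity'] = motion_intensity
    return control_params

print(build_control_params("Motion Transfer", 75))  # {'motion_intensity': 75}
print(build_control_params("Pose Transfer", 75))    # {} - intensity is ignored for Pose Transfer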
@@ -667,14 +693,13 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
     ):
         validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
         width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
-        num_frames = get_total_frames_from_length()
 
-        inference_params = MoonvalleyTextToVideoInferenceParams(
+        inference_params=MoonvalleyTextToVideoInferenceParams(
             negative_prompt=negative_prompt,
             steps=kwargs.get("steps"),
             seed=kwargs.get("seed"),
             guidance_scale=kwargs.get("prompt_adherence"),
-            num_frames=num_frames,
+            num_frames=128,
             width=width_height.get("width"),
             height=width_height.get("height"),
         )
@@ -707,22 +732,12 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
 NODE_CLASS_MAPPINGS = {
     "MoonvalleyImg2VideoNode": MoonvalleyImg2VideoNode,
     "MoonvalleyTxt2VideoNode": MoonvalleyTxt2VideoNode,
-    # "MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode,
+    "MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode,
 }
 
 
 NODE_DISPLAY_NAME_MAPPINGS = {
     "MoonvalleyImg2VideoNode": "Moonvalley Marey Image to Video",
     "MoonvalleyTxt2VideoNode": "Moonvalley Marey Text to Video",
-    # "MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video",
+    "MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video",
 }
-
-
-def get_total_frames_from_length(length="5s"):
-    # if length == '5s':
-    #     return 128
-    # elif length == '10s':
-    #     return 256
-    return 128
-    # else:
-    #     raise MoonvalleyApiError("length is required")
@@ -685,6 +685,49 @@ class WanTrackToVideo:
         out_latent["samples"] = latent
         return (positive, negative, out_latent)
 
+
+class Wan22ImageToVideoLatent:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"vae": ("VAE", ),
+                             "width": ("INT", {"default": 1280, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}),
+                             "height": ("INT", {"default": 704, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}),
+                             "length": ("INT", {"default": 49, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
+                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                             },
+                "optional": {"start_image": ("IMAGE", ),
+                }}
+
+
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/inpaint"
+
+    def encode(self, vae, width, height, length, batch_size, start_image=None):
+        latent = torch.zeros([1, 48, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
+
+        if start_image is None:
+            out_latent = {}
+            out_latent["samples"] = latent
+            return (out_latent,)
+
+        mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
+
+        if start_image is not None:
+            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+            latent_temp = vae.encode(start_image)
+            latent[:, :, :latent_temp.shape[-3]] = latent_temp
+            mask[:, :, :latent_temp.shape[-3]] *= 0.0
+
+        out_latent = {}
+        latent_format = comfy.latent_formats.Wan22()
+        latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
+        out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
+        out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
+        return (out_latent,)
+
+
 NODE_CLASS_MAPPINGS = {
     "WanTrackToVideo": WanTrackToVideo,
     "WanImageToVideo": WanImageToVideo,
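For orientation, a minimal worked example of the latent geometry Wan22ImageToVideoLatent.encode allocates. The 4x temporal and 16x spatial reduction and the 48 latent channels come straight from the torch.zeros call in the hunk above; the concrete numbers below just plug in the node's default width/height/length and are not part of the diff itself.

# Worked example of the latent shape arithmetic in the new node (defaults assumed).
width, height, length = 1280, 704, 49
latent_frames = ((length - 1) // 4) + 1          # 49 pixel frames -> 13 latent frames
latent_h, latent_w = height // 16, width // 16   # 704x1280 pixels -> 44x80 latent
print([1, 48, latent_frames, latent_h, latent_w])  # [1, 48, 13, 44, 80]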
@@ -695,4 +738,5 @@ NODE_CLASS_MAPPINGS = {
     "TrimVideoLatent": TrimVideoLatent,
     "WanCameraImageToVideo": WanCameraImageToVideo,
     "WanPhantomSubjectToVideo": WanPhantomSubjectToVideo,
+    "Wan22ImageToVideoLatent": Wan22ImageToVideoLatent,
 }
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.45"
+__version__ = "0.3.46"
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.45"
+version = "0.3.46"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
@@ -1,5 +1,5 @@
 comfyui-frontend-package==1.23.4
-comfyui-workflow-templates==0.1.40
+comfyui-workflow-templates==0.1.41
 comfyui-embedded-docs==0.2.4
 torch
 torchsde