Support LTXV 0.9.5.

Credits: Lightricks team.
This commit is contained in:
comfyanonymous
2025-03-05 00:13:49 -05:00
parent 745b13649b
commit 93fedd92fe
11 changed files with 661 additions and 141 deletions

View File

@@ -7,7 +7,7 @@ from einops import rearrange
import math
from typing import Dict, Optional, Tuple
from .symmetric_patchifier import SymmetricPatchifier
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
def get_timestep_embedding(
@@ -377,12 +377,16 @@ class LTXVModel(torch.nn.Module):
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
causal_temporal_positioning=False,
vae_scale_factors=(8, 32, 32),
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.generator = None
self.vae_scale_factors = vae_scale_factors
self.dtype = dtype
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.causal_temporal_positioning = causal_temporal_positioning
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
@@ -416,42 +420,23 @@ class LTXVModel(torch.nn.Module):
self.patchifier = SymmetricPatchifier(1)
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
indices_grid = self.patchifier.get_grid(
orig_num_frames=x.shape[2],
orig_height=x.shape[3],
orig_width=x.shape[4],
batch_size=x.shape[0],
scale_grid=((1 / frame_rate) * 8, 32, 32),
device=x.device,
)
if guiding_latent is not None:
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
ts *= input_ts
ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
timestep = self.patchifier.patchify(ts)
input_x = x.clone()
x[:, :, 0] = guiding_latent[:, :, 0]
if guiding_latent_noise_scale > 0:
if self.generator is None:
self.generator = torch.Generator(device=x.device).manual_seed(42)
elif self.generator.device != x.device:
self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
scale = guiding_latent_noise_scale * (input_ts ** 2)
guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
orig_shape = list(x.shape)
x = self.patchifier.patchify(x)
x, latent_coords = self.patchifier.patchify(x)
pixel_coords = latent_to_pixel_coords(
latent_coords=latent_coords,
scale_factors=self.vae_scale_factors,
causal_fix=self.causal_temporal_positioning,
)
if keyframe_idxs is not None:
pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
fractional_coords = pixel_coords.to(torch.float32)
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
x = self.patchify_proj(x)
timestep = timestep * 1000.0
@@ -459,7 +444,7 @@ class LTXVModel(torch.nn.Module):
if attention_mask is not None and not torch.is_floating_point(attention_mask):
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
batch_size = x.shape[0]
timestep, embedded_timestep = self.adaln_single(
@@ -519,8 +504,4 @@ class LTXVModel(torch.nn.Module):
out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
)
if guiding_latent is not None:
x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
# print("res", x)
return x

View File

@@ -6,16 +6,29 @@ from einops import rearrange
from torch import Tensor
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
)
elif dims_to_append == 0:
return x
return x[(...,) + (None,) * dims_to_append]
def latent_to_pixel_coords(
latent_coords: Tensor, scale_factors: Tuple[int, int, int], causal_fix: bool = False
) -> Tensor:
"""
Converts latent coordinates to pixel coordinates by scaling them according to the VAE's
configuration.
Args:
latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents]
containing the latent corner coordinates of each token.
scale_factors (Tuple[int, int, int]): The scale factors of the VAE's latent space.
causal_fix (bool): Whether to take into account the different temporal scale
of the first frame. Default = False for backwards compatibility.
Returns:
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
"""
pixel_coords = (
latent_coords
* torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
)
if causal_fix:
# Fix temporal scale for first frame to 1 due to causality
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
return pixel_coords
class Patchifier(ABC):
@@ -44,29 +57,26 @@ class Patchifier(ABC):
def patch_size(self):
return self._patch_size
def get_grid(
self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
def get_latent_coords(
self, latent_num_frames, latent_height, latent_width, batch_size, device
):
f = orig_num_frames // self._patch_size[0]
h = orig_height // self._patch_size[1]
w = orig_width // self._patch_size[2]
grid_h = torch.arange(h, dtype=torch.float32, device=device)
grid_w = torch.arange(w, dtype=torch.float32, device=device)
grid_f = torch.arange(f, dtype=torch.float32, device=device)
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing='ij')
grid = torch.stack(grid, dim=0)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
if scale_grid is not None:
for i in range(3):
if isinstance(scale_grid[i], Tensor):
scale = append_dims(scale_grid[i], grid.ndim - 1)
else:
scale = scale_grid[i]
grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
return grid
"""
Return a tensor of shape [batch_size, 3, num_patches] containing the
top-left corner latent coordinates of each latent patch.
The tensor is repeated for each batch element.
"""
latent_sample_coords = torch.meshgrid(
torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
torch.arange(0, latent_height, self._patch_size[1], device=device),
torch.arange(0, latent_width, self._patch_size[2], device=device),
indexing="ij",
)
latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
latent_coords = rearrange(
latent_coords, "b c f h w -> b c (f h w)", b=batch_size
)
return latent_coords
class SymmetricPatchifier(Patchifier):
@@ -74,6 +84,8 @@ class SymmetricPatchifier(Patchifier):
self,
latents: Tensor,
) -> Tuple[Tensor, Tensor]:
b, _, f, h, w = latents.shape
latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
latents = rearrange(
latents,
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
@@ -81,7 +93,7 @@ class SymmetricPatchifier(Patchifier):
p2=self._patch_size[1],
p3=self._patch_size[2],
)
return latents
return latents, latent_coords
def unpatchify(
self,

View File

@@ -15,6 +15,7 @@ class CausalConv3d(nn.Module):
stride: Union[int, Tuple[int]] = 1,
dilation: int = 1,
groups: int = 1,
spatial_padding_mode: str = "zeros",
**kwargs,
):
super().__init__()
@@ -38,7 +39,7 @@ class CausalConv3d(nn.Module):
stride=stride,
dilation=dilation,
padding=padding,
padding_mode="zeros",
padding_mode=spatial_padding_mode,
groups=groups,
)

View File

@@ -1,13 +1,15 @@
from __future__ import annotations
import torch
from torch import nn
from functools import partial
import math
from einops import rearrange
from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
from .conv_nd_factory import make_conv_nd, make_linear_nd
from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops
ops = comfy.ops.disable_weight_init
class Encoder(nn.Module):
@@ -32,7 +34,7 @@ class Encoder(nn.Module):
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
latent_log_var (`str`, *optional*, defaults to `per_channel`):
The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`.
"""
def __init__(
@@ -40,12 +42,13 @@ class Encoder(nn.Module):
dims: Union[int, Tuple[int, int]] = 3,
in_channels: int = 3,
out_channels: int = 3,
blocks=[("res_x", 1)],
blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
base_channels: int = 128,
norm_num_groups: int = 32,
patch_size: Union[int, Tuple[int]] = 1,
norm_layer: str = "group_norm", # group_norm, pixel_norm
latent_log_var: str = "per_channel",
spatial_padding_mode: str = "zeros",
):
super().__init__()
self.patch_size = patch_size
@@ -65,6 +68,7 @@ class Encoder(nn.Module):
stride=1,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.down_blocks = nn.ModuleList([])
@@ -82,6 +86,7 @@ class Encoder(nn.Module):
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel
@@ -92,6 +97,7 @@ class Encoder(nn.Module):
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
block = make_conv_nd(
@@ -101,6 +107,7 @@ class Encoder(nn.Module):
kernel_size=3,
stride=(2, 1, 1),
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
block = make_conv_nd(
@@ -110,6 +117,7 @@ class Encoder(nn.Module):
kernel_size=3,
stride=(1, 2, 2),
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
block = make_conv_nd(
@@ -119,6 +127,7 @@ class Encoder(nn.Module):
kernel_size=3,
stride=(2, 2, 2),
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel
@@ -129,6 +138,34 @@ class Encoder(nn.Module):
kernel_size=3,
stride=(2, 2, 2),
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all_res":
output_channel = block_params.get("multiplier", 2) * output_channel
block = SpaceToDepthDownsample(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
stride=(2, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space_res":
output_channel = block_params.get("multiplier", 2) * output_channel
block = SpaceToDepthDownsample(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
stride=(1, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time_res":
output_channel = block_params.get("multiplier", 2) * output_channel
block = SpaceToDepthDownsample(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
stride=(2, 1, 1),
spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"unknown block: {block_name}")
@@ -152,10 +189,18 @@ class Encoder(nn.Module):
conv_out_channels *= 2
elif latent_log_var == "uniform":
conv_out_channels += 1
elif latent_log_var == "constant":
conv_out_channels += 1
elif latent_log_var != "none":
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
self.conv_out = make_conv_nd(
dims, output_channel, conv_out_channels, 3, padding=1, causal=True
dims,
output_channel,
conv_out_channels,
3,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.gradient_checkpointing = False
@@ -197,6 +242,15 @@ class Encoder(nn.Module):
sample = torch.cat([sample, repeated_last_channel], dim=1)
else:
raise ValueError(f"Invalid input shape: {sample.shape}")
elif self.latent_log_var == "constant":
sample = sample[:, :-1, ...]
approx_ln_0 = (
-30
) # this is the minimal clamp value in DiagonalGaussianDistribution objects
sample = torch.cat(
[sample, torch.ones_like(sample, device=sample.device) * approx_ln_0],
dim=1,
)
return sample
@@ -231,7 +285,7 @@ class Decoder(nn.Module):
dims,
in_channels: int = 3,
out_channels: int = 3,
blocks=[("res_x", 1)],
blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
base_channels: int = 128,
layers_per_block: int = 2,
norm_num_groups: int = 32,
@@ -239,6 +293,7 @@ class Decoder(nn.Module):
norm_layer: str = "group_norm",
causal: bool = True,
timestep_conditioning: bool = False,
spatial_padding_mode: str = "zeros",
):
super().__init__()
self.patch_size = patch_size
@@ -264,6 +319,7 @@ class Decoder(nn.Module):
stride=1,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.up_blocks = nn.ModuleList([])
@@ -283,6 +339,7 @@ class Decoder(nn.Module):
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "attn_res_x":
block = UNetMidBlock3D(
@@ -294,6 +351,7 @@ class Decoder(nn.Module):
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
attention_head_dim=block_params["attention_head_dim"],
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
output_channel = output_channel // block_params.get("multiplier", 2)
@@ -306,14 +364,21 @@ class Decoder(nn.Module):
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=False,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
block = DepthToSpaceUpsample(
dims=dims, in_channels=input_channel, stride=(2, 1, 1)
dims=dims,
in_channels=input_channel,
stride=(2, 1, 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
block = DepthToSpaceUpsample(
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
dims=dims,
in_channels=input_channel,
stride=(1, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
output_channel = output_channel // block_params.get("multiplier", 1)
@@ -323,6 +388,7 @@ class Decoder(nn.Module):
stride=(2, 2, 2),
residual=block_params.get("residual", False),
out_channels_reduction_factor=block_params.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"unknown layer: {block_name}")
@@ -340,7 +406,13 @@ class Decoder(nn.Module):
self.conv_act = nn.SiLU()
self.conv_out = make_conv_nd(
dims, output_channel, out_channels, 3, padding=1, causal=True
dims,
output_channel,
out_channels,
3,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.gradient_checkpointing = False
@@ -433,6 +505,12 @@ class UNetMidBlock3D(nn.Module):
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
resnet_groups (`int`, *optional*, defaults to 32):
The number of groups to use in the group normalization layers of the resnet blocks.
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
inject_noise (`bool`, *optional*, defaults to `False`):
Whether to inject noise into the hidden states.
timestep_conditioning (`bool`, *optional*, defaults to `False`):
Whether to condition the hidden states on the timestep.
Returns:
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
@@ -451,6 +529,7 @@ class UNetMidBlock3D(nn.Module):
norm_layer: str = "group_norm",
inject_noise: bool = False,
timestep_conditioning: bool = False,
spatial_padding_mode: str = "zeros",
):
super().__init__()
resnet_groups = (
@@ -476,13 +555,17 @@ class UNetMidBlock3D(nn.Module):
norm_layer=norm_layer,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
for _ in range(num_layers)
]
)
def forward(
self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None
self,
hidden_states: torch.FloatTensor,
causal: bool = True,
timestep: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
timestep_embed = None
if self.timestep_conditioning:
@@ -507,9 +590,62 @@ class UNetMidBlock3D(nn.Module):
return hidden_states
class SpaceToDepthDownsample(nn.Module):
def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode):
super().__init__()
self.stride = stride
self.group_size = in_channels * math.prod(stride) // out_channels
self.conv = make_conv_nd(
dims=dims,
in_channels=in_channels,
out_channels=out_channels // math.prod(stride),
kernel_size=3,
stride=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
def forward(self, x, causal: bool = True):
if self.stride[0] == 2:
x = torch.cat(
[x[:, :, :1, :, :], x], dim=2
) # duplicate first frames for padding
# skip connection
x_in = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size)
x_in = x_in.mean(dim=2)
# conv
x = self.conv(x, causal=causal)
x = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
x = x + x_in
return x
class DepthToSpaceUpsample(nn.Module):
def __init__(
self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
self,
dims,
in_channels,
stride,
residual=False,
out_channels_reduction_factor=1,
spatial_padding_mode="zeros",
):
super().__init__()
self.stride = stride
@@ -523,6 +659,7 @@ class DepthToSpaceUpsample(nn.Module):
kernel_size=3,
stride=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.residual = residual
self.out_channels_reduction_factor = out_channels_reduction_factor
@@ -591,6 +728,7 @@ class ResnetBlock3D(nn.Module):
norm_layer: str = "group_norm",
inject_noise: bool = False,
timestep_conditioning: bool = False,
spatial_padding_mode: str = "zeros",
):
super().__init__()
self.in_channels = in_channels
@@ -617,6 +755,7 @@ class ResnetBlock3D(nn.Module):
stride=1,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
if inject_noise:
@@ -641,6 +780,7 @@ class ResnetBlock3D(nn.Module):
stride=1,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
if inject_noise:
@@ -801,9 +941,44 @@ class processor(nn.Module):
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
class VideoVAE(nn.Module):
def __init__(self, version=0):
def __init__(self, version=0, config=None):
super().__init__()
if config is None:
config = self.guess_config(version)
self.timestep_conditioning = config.get("timestep_conditioning", False)
double_z = config.get("double_z", True)
latent_log_var = config.get(
"latent_log_var", "per_channel" if double_z else "none"
)
self.encoder = Encoder(
dims=config["dims"],
in_channels=config.get("in_channels", 3),
out_channels=config["latent_channels"],
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
)
self.decoder = Decoder(
dims=config["dims"],
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
timestep_conditioning=self.timestep_conditioning,
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
)
self.per_channel_statistics = processor()
def guess_config(self, version):
if version == 0:
config = {
"_class_name": "CausalVideoAutoencoder",
@@ -830,7 +1005,7 @@ class VideoVAE(nn.Module):
"use_quant_conv": False,
"causal_decoder": False,
}
else:
elif version == 1:
config = {
"_class_name": "CausalVideoAutoencoder",
"dims": 3,
@@ -866,37 +1041,47 @@ class VideoVAE(nn.Module):
"causal_decoder": False,
"timestep_conditioning": True,
}
double_z = config.get("double_z", True)
latent_log_var = config.get(
"latent_log_var", "per_channel" if double_z else "none"
)
self.encoder = Encoder(
dims=config["dims"],
in_channels=config.get("in_channels", 3),
out_channels=config["latent_channels"],
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
)
self.decoder = Decoder(
dims=config["dims"],
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
timestep_conditioning=config.get("timestep_conditioning", False),
)
self.timestep_conditioning = config.get("timestep_conditioning", False)
self.per_channel_statistics = processor()
else:
config = {
"_class_name": "CausalVideoAutoencoder",
"dims": 3,
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"encoder_blocks": [
["res_x", {"num_layers": 4}],
["compress_space_res", {"multiplier": 2}],
["res_x", {"num_layers": 6}],
["compress_time_res", {"multiplier": 2}],
["res_x", {"num_layers": 6}],
["compress_all_res", {"multiplier": 2}],
["res_x", {"num_layers": 2}],
["compress_all_res", {"multiplier": 2}],
["res_x", {"num_layers": 2}]
],
"decoder_blocks": [
["res_x", {"num_layers": 5, "inject_noise": False}],
["compress_all", {"residual": True, "multiplier": 2}],
["res_x", {"num_layers": 5, "inject_noise": False}],
["compress_all", {"residual": True, "multiplier": 2}],
["res_x", {"num_layers": 5, "inject_noise": False}],
["compress_all", {"residual": True, "multiplier": 2}],
["res_x", {"num_layers": 5, "inject_noise": False}]
],
"scaling_factor": 1.0,
"norm_layer": "pixel_norm",
"patch_size": 4,
"latent_log_var": "uniform",
"use_quant_conv": False,
"causal_decoder": False,
"timestep_conditioning": True
}
return config
def encode(self, x):
frames_count = x.shape[2]
if ((frames_count - 1) % 8) != 0:
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)

View File

@@ -17,7 +17,11 @@ def make_conv_nd(
groups=1,
bias=True,
causal=False,
spatial_padding_mode="zeros",
temporal_padding_mode="zeros",
):
if not (spatial_padding_mode == temporal_padding_mode or causal):
raise NotImplementedError("spatial and temporal padding modes must be equal")
if dims == 2:
return ops.Conv2d(
in_channels=in_channels,
@@ -28,6 +32,7 @@ def make_conv_nd(
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=spatial_padding_mode,
)
elif dims == 3:
if causal:
@@ -40,6 +45,7 @@ def make_conv_nd(
dilation=dilation,
groups=groups,
bias=bias,
spatial_padding_mode=spatial_padding_mode,
)
return ops.Conv3d(
in_channels=in_channels,
@@ -50,6 +56,7 @@ def make_conv_nd(
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=spatial_padding_mode,
)
elif dims == (2, 1):
return DualConv3d(
@@ -59,6 +66,7 @@ def make_conv_nd(
stride=stride,
padding=padding,
bias=bias,
padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"unsupported dimensions: {dims}")

View File

@@ -18,11 +18,13 @@ class DualConv3d(nn.Module):
dilation: Union[int, Tuple[int, int, int]] = 1,
groups=1,
bias=True,
padding_mode="zeros",
):
super(DualConv3d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.padding_mode = padding_mode
# Ensure kernel_size, stride, padding, and dilation are tuples of length 3
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
@@ -108,6 +110,7 @@ class DualConv3d(nn.Module):
self.padding1,
self.dilation1,
self.groups,
padding_mode=self.padding_mode,
)
if skip_time_conv:
@@ -122,6 +125,7 @@ class DualConv3d(nn.Module):
self.padding2,
self.dilation2,
self.groups,
padding_mode=self.padding_mode,
)
return x
@@ -137,7 +141,16 @@ class DualConv3d(nn.Module):
stride1 = (self.stride1[1], self.stride1[2])
padding1 = (self.padding1[1], self.padding1[2])
dilation1 = (self.dilation1[1], self.dilation1[2])
x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
x = F.conv2d(
x,
weight1,
self.bias1,
stride1,
padding1,
dilation1,
self.groups,
padding_mode=self.padding_mode,
)
_, _, h, w = x.shape
@@ -154,7 +167,16 @@ class DualConv3d(nn.Module):
stride2 = self.stride2[0]
padding2 = self.padding2[0]
dilation2 = self.dilation2[0]
x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
x = F.conv1d(
x,
weight2,
self.bias2,
stride2,
padding2,
dilation2,
self.groups,
padding_mode=self.padding_mode,
)
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
return x