mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 20:17:30 +00:00
Support LTXV 0.9.5.
Credits: Lightricks team.
This commit is contained in:
@@ -7,7 +7,7 @@ from einops import rearrange
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from .symmetric_patchifier import SymmetricPatchifier
|
||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
@@ -377,12 +377,16 @@ class LTXVModel(torch.nn.Module):
|
||||
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
causal_temporal_positioning=False,
|
||||
vae_scale_factors=(8, 32, 32),
|
||||
dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.generator = None
|
||||
self.vae_scale_factors = vae_scale_factors
|
||||
self.dtype = dtype
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
|
||||
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
@@ -416,42 +420,23 @@ class LTXVModel(torch.nn.Module):
|
||||
|
||||
self.patchifier = SymmetricPatchifier(1)
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs):
|
||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
indices_grid = self.patchifier.get_grid(
|
||||
orig_num_frames=x.shape[2],
|
||||
orig_height=x.shape[3],
|
||||
orig_width=x.shape[4],
|
||||
batch_size=x.shape[0],
|
||||
scale_grid=((1 / frame_rate) * 8, 32, 32),
|
||||
device=x.device,
|
||||
)
|
||||
|
||||
if guiding_latent is not None:
|
||||
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
|
||||
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
|
||||
ts *= input_ts
|
||||
ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
|
||||
timestep = self.patchifier.patchify(ts)
|
||||
input_x = x.clone()
|
||||
x[:, :, 0] = guiding_latent[:, :, 0]
|
||||
if guiding_latent_noise_scale > 0:
|
||||
if self.generator is None:
|
||||
self.generator = torch.Generator(device=x.device).manual_seed(42)
|
||||
elif self.generator.device != x.device:
|
||||
self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
|
||||
|
||||
noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
|
||||
scale = guiding_latent_noise_scale * (input_ts ** 2)
|
||||
guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
|
||||
|
||||
x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
|
||||
|
||||
|
||||
orig_shape = list(x.shape)
|
||||
|
||||
x = self.patchifier.patchify(x)
|
||||
x, latent_coords = self.patchifier.patchify(x)
|
||||
pixel_coords = latent_to_pixel_coords(
|
||||
latent_coords=latent_coords,
|
||||
scale_factors=self.vae_scale_factors,
|
||||
causal_fix=self.causal_temporal_positioning,
|
||||
)
|
||||
|
||||
if keyframe_idxs is not None:
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
|
||||
|
||||
fractional_coords = pixel_coords.to(torch.float32)
|
||||
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
|
||||
|
||||
x = self.patchify_proj(x)
|
||||
timestep = timestep * 1000.0
|
||||
@@ -459,7 +444,7 @@ class LTXVModel(torch.nn.Module):
|
||||
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
||||
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
|
||||
|
||||
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
|
||||
pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
|
||||
|
||||
batch_size = x.shape[0]
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
@@ -519,8 +504,4 @@ class LTXVModel(torch.nn.Module):
|
||||
out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
|
||||
)
|
||||
|
||||
if guiding_latent is not None:
|
||||
x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
|
||||
|
||||
# print("res", x)
|
||||
return x
|
||||
|
@@ -6,16 +6,29 @@ from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(
|
||||
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
|
||||
)
|
||||
elif dims_to_append == 0:
|
||||
return x
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
def latent_to_pixel_coords(
|
||||
latent_coords: Tensor, scale_factors: Tuple[int, int, int], causal_fix: bool = False
|
||||
) -> Tensor:
|
||||
"""
|
||||
Converts latent coordinates to pixel coordinates by scaling them according to the VAE's
|
||||
configuration.
|
||||
Args:
|
||||
latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents]
|
||||
containing the latent corner coordinates of each token.
|
||||
scale_factors (Tuple[int, int, int]): The scale factors of the VAE's latent space.
|
||||
causal_fix (bool): Whether to take into account the different temporal scale
|
||||
of the first frame. Default = False for backwards compatibility.
|
||||
Returns:
|
||||
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
|
||||
"""
|
||||
pixel_coords = (
|
||||
latent_coords
|
||||
* torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
|
||||
)
|
||||
if causal_fix:
|
||||
# Fix temporal scale for first frame to 1 due to causality
|
||||
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
|
||||
return pixel_coords
|
||||
|
||||
|
||||
class Patchifier(ABC):
|
||||
@@ -44,29 +57,26 @@ class Patchifier(ABC):
|
||||
def patch_size(self):
|
||||
return self._patch_size
|
||||
|
||||
def get_grid(
|
||||
self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
|
||||
def get_latent_coords(
|
||||
self, latent_num_frames, latent_height, latent_width, batch_size, device
|
||||
):
|
||||
f = orig_num_frames // self._patch_size[0]
|
||||
h = orig_height // self._patch_size[1]
|
||||
w = orig_width // self._patch_size[2]
|
||||
grid_h = torch.arange(h, dtype=torch.float32, device=device)
|
||||
grid_w = torch.arange(w, dtype=torch.float32, device=device)
|
||||
grid_f = torch.arange(f, dtype=torch.float32, device=device)
|
||||
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing='ij')
|
||||
grid = torch.stack(grid, dim=0)
|
||||
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
|
||||
if scale_grid is not None:
|
||||
for i in range(3):
|
||||
if isinstance(scale_grid[i], Tensor):
|
||||
scale = append_dims(scale_grid[i], grid.ndim - 1)
|
||||
else:
|
||||
scale = scale_grid[i]
|
||||
grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
|
||||
|
||||
grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
|
||||
return grid
|
||||
"""
|
||||
Return a tensor of shape [batch_size, 3, num_patches] containing the
|
||||
top-left corner latent coordinates of each latent patch.
|
||||
The tensor is repeated for each batch element.
|
||||
"""
|
||||
latent_sample_coords = torch.meshgrid(
|
||||
torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
|
||||
torch.arange(0, latent_height, self._patch_size[1], device=device),
|
||||
torch.arange(0, latent_width, self._patch_size[2], device=device),
|
||||
indexing="ij",
|
||||
)
|
||||
latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
|
||||
latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
latent_coords = rearrange(
|
||||
latent_coords, "b c f h w -> b c (f h w)", b=batch_size
|
||||
)
|
||||
return latent_coords
|
||||
|
||||
|
||||
class SymmetricPatchifier(Patchifier):
|
||||
@@ -74,6 +84,8 @@ class SymmetricPatchifier(Patchifier):
|
||||
self,
|
||||
latents: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
b, _, f, h, w = latents.shape
|
||||
latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
|
||||
latents = rearrange(
|
||||
latents,
|
||||
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
|
||||
@@ -81,7 +93,7 @@ class SymmetricPatchifier(Patchifier):
|
||||
p2=self._patch_size[1],
|
||||
p3=self._patch_size[2],
|
||||
)
|
||||
return latents
|
||||
return latents, latent_coords
|
||||
|
||||
def unpatchify(
|
||||
self,
|
||||
|
@@ -15,6 +15,7 @@ class CausalConv3d(nn.Module):
|
||||
stride: Union[int, Tuple[int]] = 1,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
spatial_padding_mode: str = "zeros",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -38,7 +39,7 @@ class CausalConv3d(nn.Module):
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=padding,
|
||||
padding_mode="zeros",
|
||||
padding_mode=spatial_padding_mode,
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
|
@@ -1,13 +1,15 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
from torch import nn
|
||||
from functools import partial
|
||||
import math
|
||||
from einops import rearrange
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from .conv_nd_factory import make_conv_nd, make_linear_nd
|
||||
from .pixel_norm import PixelNorm
|
||||
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
import comfy.ops
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
class Encoder(nn.Module):
|
||||
@@ -32,7 +34,7 @@ class Encoder(nn.Module):
|
||||
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
||||
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
||||
latent_log_var (`str`, *optional*, defaults to `per_channel`):
|
||||
The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
|
||||
The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -40,12 +42,13 @@ class Encoder(nn.Module):
|
||||
dims: Union[int, Tuple[int, int]] = 3,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
blocks=[("res_x", 1)],
|
||||
blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
|
||||
base_channels: int = 128,
|
||||
norm_num_groups: int = 32,
|
||||
patch_size: Union[int, Tuple[int]] = 1,
|
||||
norm_layer: str = "group_norm", # group_norm, pixel_norm
|
||||
latent_log_var: str = "per_channel",
|
||||
spatial_padding_mode: str = "zeros",
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
@@ -65,6 +68,7 @@ class Encoder(nn.Module):
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
@@ -82,6 +86,7 @@ class Encoder(nn.Module):
|
||||
resnet_eps=1e-6,
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
@@ -92,6 +97,7 @@ class Encoder(nn.Module):
|
||||
eps=1e-6,
|
||||
groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
block = make_conv_nd(
|
||||
@@ -101,6 +107,7 @@ class Encoder(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=(2, 1, 1),
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
block = make_conv_nd(
|
||||
@@ -110,6 +117,7 @@ class Encoder(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=(1, 2, 2),
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
block = make_conv_nd(
|
||||
@@ -119,6 +127,7 @@ class Encoder(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all_x_y":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
@@ -129,6 +138,34 @@ class Encoder(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all_res":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
block = SpaceToDepthDownsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
stride=(2, 2, 2),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space_res":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
block = SpaceToDepthDownsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
stride=(1, 2, 2),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time_res":
|
||||
output_channel = block_params.get("multiplier", 2) * output_channel
|
||||
block = SpaceToDepthDownsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
stride=(2, 1, 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown block: {block_name}")
|
||||
@@ -152,10 +189,18 @@ class Encoder(nn.Module):
|
||||
conv_out_channels *= 2
|
||||
elif latent_log_var == "uniform":
|
||||
conv_out_channels += 1
|
||||
elif latent_log_var == "constant":
|
||||
conv_out_channels += 1
|
||||
elif latent_log_var != "none":
|
||||
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
|
||||
self.conv_out = make_conv_nd(
|
||||
dims, output_channel, conv_out_channels, 3, padding=1, causal=True
|
||||
dims,
|
||||
output_channel,
|
||||
conv_out_channels,
|
||||
3,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
@@ -197,6 +242,15 @@ class Encoder(nn.Module):
|
||||
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
||||
else:
|
||||
raise ValueError(f"Invalid input shape: {sample.shape}")
|
||||
elif self.latent_log_var == "constant":
|
||||
sample = sample[:, :-1, ...]
|
||||
approx_ln_0 = (
|
||||
-30
|
||||
) # this is the minimal clamp value in DiagonalGaussianDistribution objects
|
||||
sample = torch.cat(
|
||||
[sample, torch.ones_like(sample, device=sample.device) * approx_ln_0],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
return sample
|
||||
|
||||
@@ -231,7 +285,7 @@ class Decoder(nn.Module):
|
||||
dims,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
blocks=[("res_x", 1)],
|
||||
blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
|
||||
base_channels: int = 128,
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
@@ -239,6 +293,7 @@ class Decoder(nn.Module):
|
||||
norm_layer: str = "group_norm",
|
||||
causal: bool = True,
|
||||
timestep_conditioning: bool = False,
|
||||
spatial_padding_mode: str = "zeros",
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
@@ -264,6 +319,7 @@ class Decoder(nn.Module):
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
@@ -283,6 +339,7 @@ class Decoder(nn.Module):
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_params.get("inject_noise", False),
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "attn_res_x":
|
||||
block = UNetMidBlock3D(
|
||||
@@ -294,6 +351,7 @@ class Decoder(nn.Module):
|
||||
inject_noise=block_params.get("inject_noise", False),
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
attention_head_dim=block_params["attention_head_dim"],
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
output_channel = output_channel // block_params.get("multiplier", 2)
|
||||
@@ -306,14 +364,21 @@ class Decoder(nn.Module):
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_params.get("inject_noise", False),
|
||||
timestep_conditioning=False,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims, in_channels=input_channel, stride=(2, 1, 1)
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(2, 1, 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(1, 2, 2),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
output_channel = output_channel // block_params.get("multiplier", 1)
|
||||
@@ -323,6 +388,7 @@ class Decoder(nn.Module):
|
||||
stride=(2, 2, 2),
|
||||
residual=block_params.get("residual", False),
|
||||
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown layer: {block_name}")
|
||||
@@ -340,7 +406,13 @@ class Decoder(nn.Module):
|
||||
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = make_conv_nd(
|
||||
dims, output_channel, out_channels, 3, padding=1, causal=True
|
||||
dims,
|
||||
output_channel,
|
||||
out_channels,
|
||||
3,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
@@ -433,6 +505,12 @@ class UNetMidBlock3D(nn.Module):
|
||||
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
||||
resnet_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use in the group normalization layers of the resnet blocks.
|
||||
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
||||
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
||||
inject_noise (`bool`, *optional*, defaults to `False`):
|
||||
Whether to inject noise into the hidden states.
|
||||
timestep_conditioning (`bool`, *optional*, defaults to `False`):
|
||||
Whether to condition the hidden states on the timestep.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
||||
@@ -451,6 +529,7 @@ class UNetMidBlock3D(nn.Module):
|
||||
norm_layer: str = "group_norm",
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
spatial_padding_mode: str = "zeros",
|
||||
):
|
||||
super().__init__()
|
||||
resnet_groups = (
|
||||
@@ -476,13 +555,17 @@ class UNetMidBlock3D(nn.Module):
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=inject_noise,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
causal: bool = True,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
timestep_embed = None
|
||||
if self.timestep_conditioning:
|
||||
@@ -507,9 +590,62 @@ class UNetMidBlock3D(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SpaceToDepthDownsample(nn.Module):
|
||||
def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.group_size = in_channels * math.prod(stride) // out_channels
|
||||
self.conv = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels // math.prod(stride),
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
def forward(self, x, causal: bool = True):
|
||||
if self.stride[0] == 2:
|
||||
x = torch.cat(
|
||||
[x[:, :, :1, :, :], x], dim=2
|
||||
) # duplicate first frames for padding
|
||||
|
||||
# skip connection
|
||||
x_in = rearrange(
|
||||
x,
|
||||
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size)
|
||||
x_in = x_in.mean(dim=2)
|
||||
|
||||
# conv
|
||||
x = self.conv(x, causal=causal)
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
|
||||
x = x + x_in
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class DepthToSpaceUpsample(nn.Module):
|
||||
def __init__(
|
||||
self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
|
||||
self,
|
||||
dims,
|
||||
in_channels,
|
||||
stride,
|
||||
residual=False,
|
||||
out_channels_reduction_factor=1,
|
||||
spatial_padding_mode="zeros",
|
||||
):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
@@ -523,6 +659,7 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
self.residual = residual
|
||||
self.out_channels_reduction_factor = out_channels_reduction_factor
|
||||
@@ -591,6 +728,7 @@ class ResnetBlock3D(nn.Module):
|
||||
norm_layer: str = "group_norm",
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
spatial_padding_mode: str = "zeros",
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
@@ -617,6 +755,7 @@ class ResnetBlock3D(nn.Module):
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
if inject_noise:
|
||||
@@ -641,6 +780,7 @@ class ResnetBlock3D(nn.Module):
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
if inject_noise:
|
||||
@@ -801,9 +941,44 @@ class processor(nn.Module):
|
||||
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||
|
||||
class VideoVAE(nn.Module):
|
||||
def __init__(self, version=0):
|
||||
def __init__(self, version=0, config=None):
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = self.guess_config(version)
|
||||
|
||||
self.timestep_conditioning = config.get("timestep_conditioning", False)
|
||||
double_z = config.get("double_z", True)
|
||||
latent_log_var = config.get(
|
||||
"latent_log_var", "per_channel" if double_z else "none"
|
||||
)
|
||||
|
||||
self.encoder = Encoder(
|
||||
dims=config["dims"],
|
||||
in_channels=config.get("in_channels", 3),
|
||||
out_channels=config["latent_channels"],
|
||||
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
latent_log_var=latent_log_var,
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
|
||||
)
|
||||
|
||||
self.decoder = Decoder(
|
||||
dims=config["dims"],
|
||||
in_channels=config["latent_channels"],
|
||||
out_channels=config.get("out_channels", 3),
|
||||
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
causal=config.get("causal_decoder", False),
|
||||
timestep_conditioning=self.timestep_conditioning,
|
||||
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
|
||||
)
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
def guess_config(self, version):
|
||||
if version == 0:
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
@@ -830,7 +1005,7 @@ class VideoVAE(nn.Module):
|
||||
"use_quant_conv": False,
|
||||
"causal_decoder": False,
|
||||
}
|
||||
else:
|
||||
elif version == 1:
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
"dims": 3,
|
||||
@@ -866,37 +1041,47 @@ class VideoVAE(nn.Module):
|
||||
"causal_decoder": False,
|
||||
"timestep_conditioning": True,
|
||||
}
|
||||
|
||||
double_z = config.get("double_z", True)
|
||||
latent_log_var = config.get(
|
||||
"latent_log_var", "per_channel" if double_z else "none"
|
||||
)
|
||||
|
||||
self.encoder = Encoder(
|
||||
dims=config["dims"],
|
||||
in_channels=config.get("in_channels", 3),
|
||||
out_channels=config["latent_channels"],
|
||||
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
latent_log_var=latent_log_var,
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
)
|
||||
|
||||
self.decoder = Decoder(
|
||||
dims=config["dims"],
|
||||
in_channels=config["latent_channels"],
|
||||
out_channels=config.get("out_channels", 3),
|
||||
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
causal=config.get("causal_decoder", False),
|
||||
timestep_conditioning=config.get("timestep_conditioning", False),
|
||||
)
|
||||
|
||||
self.timestep_conditioning = config.get("timestep_conditioning", False)
|
||||
self.per_channel_statistics = processor()
|
||||
else:
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
"dims": 3,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"encoder_blocks": [
|
||||
["res_x", {"num_layers": 4}],
|
||||
["compress_space_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 6}],
|
||||
["compress_time_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 6}],
|
||||
["compress_all_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 2}],
|
||||
["compress_all_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 2}]
|
||||
],
|
||||
"decoder_blocks": [
|
||||
["res_x", {"num_layers": 5, "inject_noise": False}],
|
||||
["compress_all", {"residual": True, "multiplier": 2}],
|
||||
["res_x", {"num_layers": 5, "inject_noise": False}],
|
||||
["compress_all", {"residual": True, "multiplier": 2}],
|
||||
["res_x", {"num_layers": 5, "inject_noise": False}],
|
||||
["compress_all", {"residual": True, "multiplier": 2}],
|
||||
["res_x", {"num_layers": 5, "inject_noise": False}]
|
||||
],
|
||||
"scaling_factor": 1.0,
|
||||
"norm_layer": "pixel_norm",
|
||||
"patch_size": 4,
|
||||
"latent_log_var": "uniform",
|
||||
"use_quant_conv": False,
|
||||
"causal_decoder": False,
|
||||
"timestep_conditioning": True
|
||||
}
|
||||
return config
|
||||
|
||||
def encode(self, x):
|
||||
frames_count = x.shape[2]
|
||||
if ((frames_count - 1) % 8) != 0:
|
||||
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
|
||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
|
@@ -17,7 +17,11 @@ def make_conv_nd(
|
||||
groups=1,
|
||||
bias=True,
|
||||
causal=False,
|
||||
spatial_padding_mode="zeros",
|
||||
temporal_padding_mode="zeros",
|
||||
):
|
||||
if not (spatial_padding_mode == temporal_padding_mode or causal):
|
||||
raise NotImplementedError("spatial and temporal padding modes must be equal")
|
||||
if dims == 2:
|
||||
return ops.Conv2d(
|
||||
in_channels=in_channels,
|
||||
@@ -28,6 +32,7 @@ def make_conv_nd(
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif dims == 3:
|
||||
if causal:
|
||||
@@ -40,6 +45,7 @@ def make_conv_nd(
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
return ops.Conv3d(
|
||||
in_channels=in_channels,
|
||||
@@ -50,6 +56,7 @@ def make_conv_nd(
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif dims == (2, 1):
|
||||
return DualConv3d(
|
||||
@@ -59,6 +66,7 @@ def make_conv_nd(
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=bias,
|
||||
padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
@@ -18,11 +18,13 @@ class DualConv3d(nn.Module):
|
||||
dilation: Union[int, Tuple[int, int, int]] = 1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
padding_mode="zeros",
|
||||
):
|
||||
super(DualConv3d, self).__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.padding_mode = padding_mode
|
||||
# Ensure kernel_size, stride, padding, and dilation are tuples of length 3
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
@@ -108,6 +110,7 @@ class DualConv3d(nn.Module):
|
||||
self.padding1,
|
||||
self.dilation1,
|
||||
self.groups,
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
|
||||
if skip_time_conv:
|
||||
@@ -122,6 +125,7 @@ class DualConv3d(nn.Module):
|
||||
self.padding2,
|
||||
self.dilation2,
|
||||
self.groups,
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
|
||||
return x
|
||||
@@ -137,7 +141,16 @@ class DualConv3d(nn.Module):
|
||||
stride1 = (self.stride1[1], self.stride1[2])
|
||||
padding1 = (self.padding1[1], self.padding1[2])
|
||||
dilation1 = (self.dilation1[1], self.dilation1[2])
|
||||
x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
|
||||
x = F.conv2d(
|
||||
x,
|
||||
weight1,
|
||||
self.bias1,
|
||||
stride1,
|
||||
padding1,
|
||||
dilation1,
|
||||
self.groups,
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
|
||||
_, _, h, w = x.shape
|
||||
|
||||
@@ -154,7 +167,16 @@ class DualConv3d(nn.Module):
|
||||
stride2 = self.stride2[0]
|
||||
padding2 = self.padding2[0]
|
||||
dilation2 = self.dilation2[0]
|
||||
x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
|
||||
x = F.conv1d(
|
||||
x,
|
||||
weight2,
|
||||
self.bias2,
|
||||
stride2,
|
||||
padding2,
|
||||
dilation2,
|
||||
self.groups,
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
|
||||
|
||||
return x
|
||||
|
Reference in New Issue
Block a user