mirror of https://github.com/comfyanonymous/ComfyUI.git

Commit: Mochi VAE encoder.
@@ -2,12 +2,16 @@
 #adapted to ComfyUI
 
 from typing import Callable, List, Optional, Tuple, Union
+from functools import partial
+import math
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from einops import rearrange
 
 from comfy.ldm.modules.attention import optimized_attention
 
 import comfy.ops
 ops = comfy.ops.disable_weight_init
 
@@ -158,8 +162,10 @@ class ResBlock(nn.Module):
         *,
         affine: bool = True,
         attn_block: Optional[nn.Module] = None,
-        padding_mode: str = "replicate",
         causal: bool = True,
+        prune_bottleneck: bool = False,
+        padding_mode: str,
+        bias: bool = True,
     ):
         super().__init__()
         self.channels = channels
@@ -170,23 +176,23 @@ class ResBlock(nn.Module):
             nn.SiLU(inplace=True),
             PConv3d(
                 in_channels=channels,
-                out_channels=channels,
+                out_channels=channels // 2 if prune_bottleneck else channels,
                 kernel_size=(3, 3, 3),
                 stride=(1, 1, 1),
                 padding_mode=padding_mode,
-                bias=True,
-                # causal=causal,
+                bias=bias,
+                causal=causal,
             ),
             norm_fn(channels, affine=affine),
             nn.SiLU(inplace=True),
             PConv3d(
-                in_channels=channels,
+                in_channels=channels // 2 if prune_bottleneck else channels,
                 out_channels=channels,
                 kernel_size=(3, 3, 3),
                 stride=(1, 1, 1),
                 padding_mode=padding_mode,
-                bias=True,
-                # causal=causal,
+                bias=bias,
+                causal=causal,
             ),
         )
 
@@ -206,6 +212,81 @@ class ResBlock(nn.Module):
         return self.attn_block(x)
 
 
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        head_dim: int = 32,
+        qkv_bias: bool = False,
+        out_bias: bool = True,
+        qk_norm: bool = True,
+    ) -> None:
+        super().__init__()
+        self.head_dim = head_dim
+        self.num_heads = dim // head_dim
+        self.qk_norm = qk_norm
+
+        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
+        self.out = nn.Linear(dim, dim, bias=out_bias)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute temporal self-attention.
+
+        Args:
+            x: Input tensor. Shape: [B, C, T, H, W].
+            chunk_size: Chunk size for large tensors.
+
+        Returns:
+            x: Output tensor. Shape: [B, C, T, H, W].
+        """
+        B, _, T, H, W = x.shape
+
+        if T == 1:
+            # No attention for single frame.
+            x = x.movedim(1, -1)  # [B, C, T, H, W] -> [B, T, H, W, C]
+            qkv = self.qkv(x)
+            _, _, x = qkv.chunk(3, dim=-1)  # Throw away queries and keys.
+            x = self.out(x)
+            return x.movedim(-1, 1)  # [B, T, H, W, C] -> [B, C, T, H, W]
+
+        # 1D temporal attention.
+        x = rearrange(x, "B C t h w -> (B h w) t C")
+        qkv = self.qkv(x)
+
+        # Input: qkv with shape [B, t, 3 * num_heads * head_dim]
+        # Output: x with shape [B, num_heads, t, head_dim]
+        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(2)
+
+        if self.qk_norm:
+            q = F.normalize(q, p=2, dim=-1)
+            k = F.normalize(k, p=2, dim=-1)
+
+        x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
+
+        assert x.size(0) == q.size(0)
+
+        x = self.out(x)
+        x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
+        return x
+
+
+class AttentionBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        **attn_kwargs,
+    ) -> None:
+        super().__init__()
+        self.norm = norm_fn(dim)
+        self.attn = Attention(dim, **attn_kwargs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x + self.attn(self.norm(x))
+
+
 class CausalUpsampleBlock(nn.Module):
     def __init__(
         self,
@@ -244,14 +325,9 @@ class CausalUpsampleBlock(nn.Module):
         return x
 
 
-def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
-    assert has_attention is False #NOTE: if this is ever true add back the attention code.
-
-    attn_block = None #AttentionBlock(channels) if has_attention else None
-
-    return ResBlock(
-        channels, affine=True, attn_block=attn_block, **block_kwargs
-    )
+def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
+    attn_block = AttentionBlock(channels) if has_attention else None
+    return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs)
 
 
 class DownsampleBlock(nn.Module):
@@ -288,8 +364,9 @@ class DownsampleBlock(nn.Module):
                 out_channels=out_channels,
                 kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
                 stride=(temporal_reduction, spatial_reduction, spatial_reduction),
+                # First layer in each block always uses replicate padding
                 padding_mode="replicate",
-                bias=True,
+                bias=block_kwargs["bias"],
             )
         )
 
@@ -382,7 +459,7 @@ class Decoder(nn.Module):
         blocks = []
 
         first_block = [
-            nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
+            ops.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
         ]  # Input layer.
         # First set of blocks preserve channel count.
         for _ in range(num_res_blocks[-1]):
@@ -452,11 +529,165 @@ class Decoder(nn.Module):
 
         return self.output_proj(x).contiguous()
 
+
+class LatentDistribution:
+    def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
+        """Initialize latent distribution.
+
+        Args:
+            mean: Mean of the distribution. Shape: [B, C, T, H, W].
+            logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].
+        """
+        assert mean.shape == logvar.shape
+        self.mean = mean
+        self.logvar = logvar
+
+    def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None):
+        if temperature == 0.0:
+            return self.mean
+
+        if noise is None:
+            noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator)
+        else:
+            assert noise.device == self.mean.device
+            noise = noise.to(self.mean.dtype)
+
+        if temperature != 1.0:
+            raise NotImplementedError(f"Temperature {temperature} is not supported.")
+
+        # Just Gaussian sample with no scaling of variance.
+        return noise * torch.exp(self.logvar * 0.5) + self.mean
+
+    def mode(self):
+        return self.mean
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        *,
+        in_channels: int,
+        base_channels: int,
+        channel_multipliers: List[int],
+        num_res_blocks: List[int],
+        latent_dim: int,
+        temporal_reductions: List[int],
+        spatial_reductions: List[int],
+        prune_bottlenecks: List[bool],
+        has_attentions: List[bool],
+        affine: bool = True,
+        bias: bool = True,
+        input_is_conv_1x1: bool = False,
+        padding_mode: str,
+    ):
+        super().__init__()
+        self.temporal_reductions = temporal_reductions
+        self.spatial_reductions = spatial_reductions
+        self.base_channels = base_channels
+        self.channel_multipliers = channel_multipliers
+        self.num_res_blocks = num_res_blocks
+        self.latent_dim = latent_dim
+
+        self.fourier_features = FourierFeatures()
+        ch = [mult * base_channels for mult in channel_multipliers]
+        num_down_blocks = len(ch) - 1
+        assert len(num_res_blocks) == num_down_blocks + 2
+
+        layers = (
+            [ops.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)]
+            if not input_is_conv_1x1
+            else [Conv1x1(in_channels, ch[0])]
+        )
+
+        assert len(prune_bottlenecks) == num_down_blocks + 2
+        assert len(has_attentions) == num_down_blocks + 2
+        block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias)
+
+        for _ in range(num_res_blocks[0]):
+            layers.append(block(ch[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0]))
+        prune_bottlenecks = prune_bottlenecks[1:]
+        has_attentions = has_attentions[1:]
+
+        assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1
+        for i in range(num_down_blocks):
+            layer = DownsampleBlock(
+                ch[i],
+                ch[i + 1],
+                num_res_blocks=num_res_blocks[i + 1],
+                temporal_reduction=temporal_reductions[i],
+                spatial_reduction=spatial_reductions[i],
+                prune_bottleneck=prune_bottlenecks[i],
+                has_attention=has_attentions[i],
+                affine=affine,
+                bias=bias,
+                padding_mode=padding_mode,
+            )
+
+            layers.append(layer)
+
+        # Additional blocks.
+        for _ in range(num_res_blocks[-1]):
+            layers.append(block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1]))
+
+        self.layers = nn.Sequential(*layers)
+
+        # Output layers.
+        self.output_norm = norm_fn(ch[-1])
+        self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False)
+
+    @property
+    def temporal_downsample(self):
+        return math.prod(self.temporal_reductions)
+
+    @property
+    def spatial_downsample(self):
+        return math.prod(self.spatial_reductions)
+
+    def forward(self, x) -> LatentDistribution:
+        """Forward pass.
+
+        Args:
+            x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]
+
+        Returns:
+            means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1].
+                   h = H // 8, w = W // 8, t - 1 = (T - 1) // 6
+            logvar: Shape: [B, latent_dim, t, h, w].
+        """
+        assert x.ndim == 5, f"Expected 5D input, got {x.shape}"
+        x = self.fourier_features(x)
+
+        x = self.layers(x)
+
+        x = self.output_norm(x)
+        x = F.silu(x, inplace=True)
+        x = self.output_proj(x)
+
+        means, logvar = torch.chunk(x, 2, dim=1)
+
+        assert means.ndim == 5
+        assert logvar.shape == means.shape
+        assert means.size(1) == self.latent_dim
+
+        return LatentDistribution(means, logvar)
+
+
 class VideoVAE(nn.Module):
     def __init__(self):
         super().__init__()
-        self.encoder = None #TODO once the model releases
+        self.encoder = Encoder(
+            in_channels=15,
+            base_channels=64,
+            channel_multipliers=[1, 2, 4, 6],
+            num_res_blocks=[3, 3, 4, 6, 3],
+            latent_dim=12,
+            temporal_reductions=[1, 2, 3],
+            spatial_reductions=[2, 2, 2],
+            prune_bottlenecks=[False, False, False, False, False],
+            has_attentions=[False, True, True, True, True],
+            affine=True,
+            bias=True,
+            input_is_conv_1x1=True,
+            padding_mode="replicate"
+        )
         self.decoder = Decoder(
             out_channels=3,
             base_channels=128,
@@ -474,7 +705,7 @@ class VideoVAE(nn.Module):
         )
 
     def encode(self, x):
-        return self.encoder(x)
+        return self.encoder(x).mode()
 
     def decode(self, x):
         return self.decoder(x)
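
For readers skimming the diff, the two mechanisms the new encoder leans on are the reparameterized sampling in LatentDistribution.sample and the spatial-to-batch folding in Attention. Below is a minimal standalone sketch of both, not part of the commit, using made-up tensor sizes and only plain torch:

import torch

B, C, T, H, W = 1, 12, 7, 8, 8  # arbitrary example sizes

# Reparameterized sampling, mirroring LatentDistribution.sample():
#   sample = mean + exp(0.5 * logvar) * noise
mean = torch.randn(B, C, T, H, W)
logvar = torch.zeros(B, C, T, H, W)  # log-variance 0 -> unit variance
noise = torch.randn_like(mean)
sample = noise * torch.exp(logvar * 0.5) + mean
mode = mean  # VideoVAE.encode() returns the distribution mode, i.e. the mean

# Temporal-only attention folds the spatial grid into the batch axis,
# equivalent to rearrange(x, "B C t h w -> (B h w) t C"): each (h, w)
# location attends over its own T frames independently.
x = torch.randn(B, C, T, H, W)
x = x.permute(0, 3, 4, 2, 1).reshape(B * H * W, T, C)
assert x.shape == (B * H * W, T, C)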