mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-13 21:16:09 +00:00
Support SVD img2vid model.
This commit is contained in:
@@ -5,6 +5,8 @@ import numpy as np
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from functools import partial
|
||||
|
||||
from .util import (
|
||||
checkpoint,
|
||||
@@ -12,8 +14,9 @@ from .util import (
|
||||
zero_module,
|
||||
normalization,
|
||||
timestep_embedding,
|
||||
AlphaBlender,
|
||||
)
|
||||
from ..attention import SpatialTransformer
|
||||
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
|
||||
from comfy.ldm.util import exists
|
||||
import comfy.ops
|
||||
|
||||
@@ -29,10 +32,15 @@ class TimestepBlock(nn.Module):
|
||||
"""
|
||||
|
||||
#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
|
||||
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
|
||||
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
|
||||
for layer in ts:
|
||||
if isinstance(layer, TimestepBlock):
|
||||
if isinstance(layer, VideoResBlock):
|
||||
x = layer(x, emb, num_video_frames, image_only_indicator)
|
||||
elif isinstance(layer, TimestepBlock):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, SpatialVideoTransformer):
|
||||
x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options)
|
||||
transformer_options["current_index"] += 1
|
||||
elif isinstance(layer, SpatialTransformer):
|
||||
x = layer(x, context, transformer_options)
|
||||
if "current_index" in transformer_options:
|
||||
@@ -145,6 +153,9 @@ class ResBlock(TimestepBlock):
|
||||
use_checkpoint=False,
|
||||
up=False,
|
||||
down=False,
|
||||
kernel_size=3,
|
||||
exchange_temb_dims=False,
|
||||
skip_t_emb=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=comfy.ops
|
||||
@@ -157,11 +168,17 @@ class ResBlock(TimestepBlock):
|
||||
self.use_conv = use_conv
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
self.exchange_temb_dims = exchange_temb_dims
|
||||
|
||||
if isinstance(kernel_size, list):
|
||||
padding = [k // 2 for k in kernel_size]
|
||||
else:
|
||||
padding = kernel_size // 2
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
nn.GroupNorm(32, channels, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
|
||||
operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.updown = up or down
|
||||
@@ -175,19 +192,24 @@ class ResBlock(TimestepBlock):
|
||||
else:
|
||||
self.h_upd = self.x_upd = nn.Identity()
|
||||
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(
|
||||
emb_channels,
|
||||
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
|
||||
),
|
||||
)
|
||||
self.skip_t_emb = skip_t_emb
|
||||
if self.skip_t_emb:
|
||||
self.emb_layers = None
|
||||
self.exchange_temb_dims = False
|
||||
else:
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(
|
||||
emb_channels,
|
||||
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
|
||||
),
|
||||
)
|
||||
self.out_layers = nn.Sequential(
|
||||
nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
|
||||
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
|
||||
),
|
||||
)
|
||||
|
||||
@@ -195,7 +217,7 @@ class ResBlock(TimestepBlock):
|
||||
self.skip_connection = nn.Identity()
|
||||
elif use_conv:
|
||||
self.skip_connection = operations.conv_nd(
|
||||
dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
|
||||
dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device
|
||||
)
|
||||
else:
|
||||
self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
|
||||
@@ -221,19 +243,110 @@ class ResBlock(TimestepBlock):
|
||||
h = in_conv(h)
|
||||
else:
|
||||
h = self.in_layers(x)
|
||||
emb_out = self.emb_layers(emb).type(h.dtype)
|
||||
while len(emb_out.shape) < len(h.shape):
|
||||
emb_out = emb_out[..., None]
|
||||
|
||||
emb_out = None
|
||||
if not self.skip_t_emb:
|
||||
emb_out = self.emb_layers(emb).type(h.dtype)
|
||||
while len(emb_out.shape) < len(h.shape):
|
||||
emb_out = emb_out[..., None]
|
||||
if self.use_scale_shift_norm:
|
||||
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
||||
scale, shift = th.chunk(emb_out, 2, dim=1)
|
||||
h = out_norm(h) * (1 + scale) + shift
|
||||
h = out_norm(h)
|
||||
if emb_out is not None:
|
||||
scale, shift = th.chunk(emb_out, 2, dim=1)
|
||||
h *= (1 + scale)
|
||||
h += shift
|
||||
h = out_rest(h)
|
||||
else:
|
||||
h = h + emb_out
|
||||
if emb_out is not None:
|
||||
if self.exchange_temb_dims:
|
||||
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
|
||||
h = h + emb_out
|
||||
h = self.out_layers(h)
|
||||
return self.skip_connection(x) + h
|
||||
|
||||
|
||||
class VideoResBlock(ResBlock):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
emb_channels: int,
|
||||
dropout: float,
|
||||
video_kernel_size=3,
|
||||
merge_strategy: str = "fixed",
|
||||
merge_factor: float = 0.5,
|
||||
out_channels=None,
|
||||
use_conv: bool = False,
|
||||
use_scale_shift_norm: bool = False,
|
||||
dims: int = 2,
|
||||
use_checkpoint: bool = False,
|
||||
up: bool = False,
|
||||
down: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=comfy.ops
|
||||
):
|
||||
super().__init__(
|
||||
channels,
|
||||
emb_channels,
|
||||
dropout,
|
||||
out_channels=out_channels,
|
||||
use_conv=use_conv,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
up=up,
|
||||
down=down,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
|
||||
self.time_stack = ResBlock(
|
||||
default(out_channels, channels),
|
||||
emb_channels,
|
||||
dropout=dropout,
|
||||
dims=3,
|
||||
out_channels=default(out_channels, channels),
|
||||
use_scale_shift_norm=False,
|
||||
use_conv=False,
|
||||
up=False,
|
||||
down=False,
|
||||
kernel_size=video_kernel_size,
|
||||
use_checkpoint=use_checkpoint,
|
||||
exchange_temb_dims=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
self.time_mixer = AlphaBlender(
|
||||
alpha=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
rearrange_pattern="b t -> b 1 t 1 1",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: th.Tensor,
|
||||
emb: th.Tensor,
|
||||
num_video_frames: int,
|
||||
image_only_indicator = None,
|
||||
) -> th.Tensor:
|
||||
x = super().forward(x, emb)
|
||||
|
||||
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
|
||||
|
||||
x = self.time_stack(
|
||||
x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
|
||||
)
|
||||
x = self.time_mixer(
|
||||
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
|
||||
)
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
return x
|
||||
|
||||
|
||||
class Timestep(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
@@ -310,6 +423,16 @@ class UNetModel(nn.Module):
|
||||
adm_in_channels=None,
|
||||
transformer_depth_middle=None,
|
||||
transformer_depth_output=None,
|
||||
use_temporal_resblock=False,
|
||||
use_temporal_attention=False,
|
||||
time_context_dim=None,
|
||||
extra_ff_mix_layer=False,
|
||||
use_spatial_context=False,
|
||||
merge_strategy=None,
|
||||
merge_factor=0.0,
|
||||
video_kernel_size=None,
|
||||
disable_temporal_crossattention=False,
|
||||
max_ddpm_temb_period=10000,
|
||||
device=None,
|
||||
operations=comfy.ops,
|
||||
):
|
||||
@@ -364,8 +487,12 @@ class UNetModel(nn.Module):
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
self.use_temporal_resblocks = use_temporal_resblock
|
||||
self.predict_codebook_ids = n_embed is not None
|
||||
|
||||
self.default_num_video_frames = None
|
||||
self.default_image_only_indicator = None
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
|
||||
@@ -402,13 +529,104 @@ class UNetModel(nn.Module):
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
|
||||
def get_attention_layer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=1,
|
||||
context_dim=None,
|
||||
use_checkpoint=False,
|
||||
disable_self_attn=False,
|
||||
):
|
||||
if use_temporal_attention:
|
||||
return SpatialVideoTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=depth,
|
||||
context_dim=context_dim,
|
||||
time_context_dim=time_context_dim,
|
||||
dropout=dropout,
|
||||
ff_in=extra_ff_mix_layer,
|
||||
use_spatial_context=use_spatial_context,
|
||||
merge_strategy=merge_strategy,
|
||||
merge_factor=merge_factor,
|
||||
checkpoint=use_checkpoint,
|
||||
use_linear=use_linear_in_transformer,
|
||||
disable_self_attn=disable_self_attn,
|
||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||
max_time_embed_period=max_ddpm_temb_period,
|
||||
dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
else:
|
||||
return SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
|
||||
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def get_resblock(
|
||||
merge_factor,
|
||||
merge_strategy,
|
||||
video_kernel_size,
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels,
|
||||
dims,
|
||||
use_checkpoint,
|
||||
use_scale_shift_norm,
|
||||
down=False,
|
||||
up=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=comfy.ops
|
||||
):
|
||||
if self.use_temporal_resblocks:
|
||||
return VideoResBlock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
channels=ch,
|
||||
emb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=out_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=down,
|
||||
up=up,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
else:
|
||||
return ResBlock(
|
||||
channels=ch,
|
||||
emb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=out_channels,
|
||||
use_checkpoint=use_checkpoint,
|
||||
dims=dims,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=down,
|
||||
up=up,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for nr in range(self.num_res_blocks[level]):
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=mult * model_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
@@ -435,11 +653,9 @@ class UNetModel(nn.Module):
|
||||
disabled_sa = False
|
||||
|
||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||
layers.append(SpatialTransformer(
|
||||
layers.append(get_attention_layer(
|
||||
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
@@ -448,10 +664,13 @@ class UNetModel(nn.Module):
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
@@ -481,10 +700,14 @@ class UNetModel(nn.Module):
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
mid_block = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=None,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
@@ -493,15 +716,18 @@ class UNetModel(nn.Module):
|
||||
operations=operations
|
||||
)]
|
||||
if transformer_depth_middle >= 0:
|
||||
mid_block += [SpatialTransformer( # always uses a self-attn
|
||||
mid_block += [get_attention_layer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=None,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
@@ -517,10 +743,13 @@ class UNetModel(nn.Module):
|
||||
for i in range(self.num_res_blocks[level] + 1):
|
||||
ich = input_block_chans.pop()
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch + ich,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch + ich,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=model_channels * mult,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
@@ -548,19 +777,21 @@ class UNetModel(nn.Module):
|
||||
|
||||
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
SpatialTransformer(
|
||||
get_attention_layer(
|
||||
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint
|
||||
)
|
||||
)
|
||||
if level and i == self.num_res_blocks[level]:
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
@@ -602,6 +833,10 @@ class UNetModel(nn.Module):
|
||||
transformer_options["current_index"] = 0
|
||||
transformer_patches = transformer_options.get("patches", {})
|
||||
|
||||
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
|
||||
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
|
||||
time_context = kwargs.get("time_context", None)
|
||||
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
@@ -616,7 +851,7 @@ class UNetModel(nn.Module):
|
||||
h = x.type(self.dtype)
|
||||
for id, module in enumerate(self.input_blocks):
|
||||
transformer_options["block"] = ("input", id)
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options)
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
h = apply_control(h, control, 'input')
|
||||
if "input_block_patch" in transformer_patches:
|
||||
patch = transformer_patches["input_block_patch"]
|
||||
@@ -630,9 +865,10 @@ class UNetModel(nn.Module):
|
||||
h = p(h, transformer_options)
|
||||
|
||||
transformer_options["block"] = ("middle", 0)
|
||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
|
||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
h = apply_control(h, control, 'middle')
|
||||
|
||||
|
||||
for id, module in enumerate(self.output_blocks):
|
||||
transformer_options["block"] = ("output", id)
|
||||
hsp = hs.pop()
|
||||
@@ -649,7 +885,7 @@ class UNetModel(nn.Module):
|
||||
output_shape = hs[-1].shape
|
||||
else:
|
||||
output_shape = None
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
h = h.type(x.dtype)
|
||||
if self.predict_codebook_ids:
|
||||
return self.id_predictor(h)
|
||||
|
@@ -13,11 +13,78 @@ import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import repeat
|
||||
from einops import repeat, rearrange
|
||||
|
||||
from comfy.ldm.util import instantiate_from_config
|
||||
import comfy.ops
|
||||
|
||||
class AlphaBlender(nn.Module):
|
||||
strategies = ["learned", "fixed", "learned_with_images"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
alpha: float,
|
||||
merge_strategy: str = "learned_with_images",
|
||||
rearrange_pattern: str = "b t -> (b t) 1 1",
|
||||
):
|
||||
super().__init__()
|
||||
self.merge_strategy = merge_strategy
|
||||
self.rearrange_pattern = rearrange_pattern
|
||||
|
||||
assert (
|
||||
merge_strategy in self.strategies
|
||||
), f"merge_strategy needs to be in {self.strategies}"
|
||||
|
||||
if self.merge_strategy == "fixed":
|
||||
self.register_buffer("mix_factor", torch.Tensor([alpha]))
|
||||
elif (
|
||||
self.merge_strategy == "learned"
|
||||
or self.merge_strategy == "learned_with_images"
|
||||
):
|
||||
self.register_parameter(
|
||||
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
||||
|
||||
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
|
||||
# skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
|
||||
if self.merge_strategy == "fixed":
|
||||
# make shape compatible
|
||||
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
|
||||
alpha = self.mix_factor
|
||||
elif self.merge_strategy == "learned":
|
||||
alpha = torch.sigmoid(self.mix_factor)
|
||||
# make shape compatible
|
||||
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
|
||||
elif self.merge_strategy == "learned_with_images":
|
||||
assert image_only_indicator is not None, "need image_only_indicator ..."
|
||||
alpha = torch.where(
|
||||
image_only_indicator.bool(),
|
||||
torch.ones(1, 1, device=image_only_indicator.device),
|
||||
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
|
||||
)
|
||||
alpha = rearrange(alpha, self.rearrange_pattern)
|
||||
# make shape compatible
|
||||
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return alpha
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x_spatial,
|
||||
x_temporal,
|
||||
image_only_indicator=None,
|
||||
) -> torch.Tensor:
|
||||
alpha = self.get_alpha(image_only_indicator)
|
||||
x = (
|
||||
alpha.to(x_spatial.dtype) * x_spatial
|
||||
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
if schedule == "linear":
|
||||
betas = (
|
||||
|
Reference in New Issue
Block a user