mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-14 05:25:23 +00:00
Initial support for the stable audio open model.
This commit is contained in:
276
comfy/ldm/audio/autoencoder.py
Normal file
276
comfy/ldm/audio/autoencoder.py
Normal file
@@ -0,0 +1,276 @@
|
||||
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing import Literal, Dict, Any
|
||||
import math
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def vae_sample(mean, scale):
|
||||
stdev = nn.functional.softplus(scale) + 1e-4
|
||||
var = stdev * stdev
|
||||
logvar = torch.log(var)
|
||||
latents = torch.randn_like(mean) * stdev + mean
|
||||
|
||||
kl = (mean * mean + var - logvar - 1).sum(1).mean()
|
||||
|
||||
return latents, kl
|
||||
|
||||
class VAEBottleneck(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.is_discrete = False
|
||||
|
||||
def encode(self, x, return_info=False, **kwargs):
|
||||
info = {}
|
||||
|
||||
mean, scale = x.chunk(2, dim=1)
|
||||
|
||||
x, kl = vae_sample(mean, scale)
|
||||
|
||||
info["kl"] = kl
|
||||
|
||||
if return_info:
|
||||
return x, info
|
||||
else:
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return x
|
||||
|
||||
|
||||
def snake_beta(x, alpha, beta):
|
||||
return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
|
||||
|
||||
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
|
||||
class SnakeBeta(nn.Module):
|
||||
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
|
||||
super(SnakeBeta, self).__init__()
|
||||
self.in_features = in_features
|
||||
|
||||
# initialize alpha
|
||||
self.alpha_logscale = alpha_logscale
|
||||
if self.alpha_logscale: # log scale alphas initialized to zeros
|
||||
self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
|
||||
self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
|
||||
else: # linear scale alphas initialized to ones
|
||||
self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
|
||||
self.beta = nn.Parameter(torch.ones(in_features) * alpha)
|
||||
|
||||
# self.alpha.requires_grad = alpha_trainable
|
||||
# self.beta.requires_grad = alpha_trainable
|
||||
|
||||
self.no_div_by_zero = 0.000000001
|
||||
|
||||
def forward(self, x):
|
||||
alpha = self.alpha.unsqueeze(0).unsqueeze(-1).to(x.device) # line up with x to [B, C, T]
|
||||
beta = self.beta.unsqueeze(0).unsqueeze(-1).to(x.device)
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
beta = torch.exp(beta)
|
||||
x = snake_beta(x, alpha, beta)
|
||||
|
||||
return x
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs))
|
||||
|
||||
def WNConvTranspose1d(*args, **kwargs):
|
||||
return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
|
||||
|
||||
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
|
||||
if activation == "elu":
|
||||
act = torch.nn.ELU()
|
||||
elif activation == "snake":
|
||||
act = SnakeBeta(channels)
|
||||
elif activation == "none":
|
||||
act = torch.nn.Identity()
|
||||
else:
|
||||
raise ValueError(f"Unknown activation {activation}")
|
||||
|
||||
if antialias:
|
||||
act = Activation1d(act)
|
||||
|
||||
return act
|
||||
|
||||
|
||||
class ResidualUnit(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
|
||||
super().__init__()
|
||||
|
||||
self.dilation = dilation
|
||||
|
||||
padding = (dilation * (7-1)) // 2
|
||||
|
||||
self.layers = nn.Sequential(
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
|
||||
WNConv1d(in_channels=in_channels, out_channels=out_channels,
|
||||
kernel_size=7, dilation=dilation, padding=padding),
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
|
||||
WNConv1d(in_channels=out_channels, out_channels=out_channels,
|
||||
kernel_size=1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
res = x
|
||||
|
||||
#x = checkpoint(self.layers, x)
|
||||
x = self.layers(x)
|
||||
|
||||
return x + res
|
||||
|
||||
class EncoderBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
|
||||
super().__init__()
|
||||
|
||||
self.layers = nn.Sequential(
|
||||
ResidualUnit(in_channels=in_channels,
|
||||
out_channels=in_channels, dilation=1, use_snake=use_snake),
|
||||
ResidualUnit(in_channels=in_channels,
|
||||
out_channels=in_channels, dilation=3, use_snake=use_snake),
|
||||
ResidualUnit(in_channels=in_channels,
|
||||
out_channels=in_channels, dilation=9, use_snake=use_snake),
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
|
||||
WNConv1d(in_channels=in_channels, out_channels=out_channels,
|
||||
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
class DecoderBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
|
||||
super().__init__()
|
||||
|
||||
if use_nearest_upsample:
|
||||
upsample_layer = nn.Sequential(
|
||||
nn.Upsample(scale_factor=stride, mode="nearest"),
|
||||
WNConv1d(in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=2*stride,
|
||||
stride=1,
|
||||
bias=False,
|
||||
padding='same')
|
||||
)
|
||||
else:
|
||||
upsample_layer = WNConvTranspose1d(in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
|
||||
|
||||
self.layers = nn.Sequential(
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
|
||||
upsample_layer,
|
||||
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
||||
dilation=1, use_snake=use_snake),
|
||||
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
||||
dilation=3, use_snake=use_snake),
|
||||
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
||||
dilation=9, use_snake=use_snake),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
class OobleckEncoder(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels=2,
|
||||
channels=128,
|
||||
latent_dim=32,
|
||||
c_mults = [1, 2, 4, 8],
|
||||
strides = [2, 4, 8, 8],
|
||||
use_snake=False,
|
||||
antialias_activation=False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
c_mults = [1] + c_mults
|
||||
|
||||
self.depth = len(c_mults)
|
||||
|
||||
layers = [
|
||||
WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
|
||||
]
|
||||
|
||||
for i in range(self.depth-1):
|
||||
layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
|
||||
|
||||
layers += [
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
|
||||
WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
|
||||
]
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
|
||||
class OobleckDecoder(nn.Module):
|
||||
def __init__(self,
|
||||
out_channels=2,
|
||||
channels=128,
|
||||
latent_dim=32,
|
||||
c_mults = [1, 2, 4, 8],
|
||||
strides = [2, 4, 8, 8],
|
||||
use_snake=False,
|
||||
antialias_activation=False,
|
||||
use_nearest_upsample=False,
|
||||
final_tanh=True):
|
||||
super().__init__()
|
||||
|
||||
c_mults = [1] + c_mults
|
||||
|
||||
self.depth = len(c_mults)
|
||||
|
||||
layers = [
|
||||
WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
|
||||
]
|
||||
|
||||
for i in range(self.depth-1, 0, -1):
|
||||
layers += [DecoderBlock(
|
||||
in_channels=c_mults[i]*channels,
|
||||
out_channels=c_mults[i-1]*channels,
|
||||
stride=strides[i-1],
|
||||
use_snake=use_snake,
|
||||
antialias_activation=antialias_activation,
|
||||
use_nearest_upsample=use_nearest_upsample
|
||||
)
|
||||
]
|
||||
|
||||
layers += [
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
|
||||
WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
|
||||
nn.Tanh() if final_tanh else nn.Identity()
|
||||
]
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
|
||||
class AudioOobleckVAE(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels=2,
|
||||
channels=128,
|
||||
latent_dim=64,
|
||||
c_mults = [1, 2, 4, 8, 16],
|
||||
strides = [2, 4, 4, 8, 8],
|
||||
use_snake=True,
|
||||
antialias_activation=False,
|
||||
use_nearest_upsample=False,
|
||||
final_tanh=False):
|
||||
super().__init__()
|
||||
self.encoder = OobleckEncoder(in_channels, channels, latent_dim * 2, c_mults, strides, use_snake, antialias_activation)
|
||||
self.decoder = OobleckDecoder(in_channels, channels, latent_dim, c_mults, strides, use_snake, antialias_activation,
|
||||
use_nearest_upsample=use_nearest_upsample, final_tanh=final_tanh)
|
||||
self.bottleneck = VAEBottleneck()
|
||||
|
||||
def encode(self, x):
|
||||
return self.bottleneck.encode(self.encoder(x))
|
||||
|
||||
def decode(self, x):
|
||||
return self.decoder(self.bottleneck.decode(x))
|
||||
|
Reference in New Issue
Block a user