Initial ACE-Step model implementation. (#7972)

This commit is contained in:
comfyanonymous 2025-05-07 05:33:34 -07:00 committed by GitHub
parent 271c9c5b9e
commit 16417b40d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 19738 additions and 4 deletions

View File

@ -466,3 +466,7 @@ class Hunyuan3Dv2mini(LatentFormat):
latent_channels = 64
latent_dimensions = 1
scale_factor = 1.0188137142395404
class ACEAudio(LatentFormat):
latent_channels = 8
latent_dimensions = 2

768
comfy/ldm/ace/attention.py Normal file
View File

@ -0,0 +1,768 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/attention.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union, Optional
import torch
import torch.nn.functional as F
from torch import nn
import comfy.model_management
class Attention(nn.Module):
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
kv_heads: Optional[int] = None,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
qk_norm: Optional[str] = None,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
scale_qk: bool = True,
only_cross_attention: bool = False,
eps: float = 1e-5,
rescale_output_factor: float = 1.0,
residual_connection: bool = False,
processor=None,
out_dim: int = None,
out_context_dim: int = None,
context_pre_only=None,
pre_only=False,
elementwise_affine: bool = True,
is_causal: bool = False,
dtype=None, device=None, operations=None
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
self.query_dim = query_dim
self.use_bias = bias
self.is_cross_attention = cross_attention_dim is not None
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.is_causal = is_causal
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.heads = out_dim // dim_head if out_dim is not None else heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self.sliceable_head_dim = heads
self.added_kv_proj_dim = added_kv_proj_dim
self.only_cross_attention = only_cross_attention
if self.added_kv_proj_dim is None and self.only_cross_attention:
raise ValueError(
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
)
self.group_norm = None
self.spatial_norm = None
self.norm_q = None
self.norm_k = None
self.norm_cross = None
self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
if not self.only_cross_attention:
# only relevant for the `AddedKVProcessor` classes
self.to_k = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
self.to_v = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
else:
self.to_k = None
self.to_v = None
self.added_proj_bias = added_proj_bias
if self.added_kv_proj_dim is not None:
self.add_k_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
self.add_v_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
if self.context_pre_only is not None:
self.add_q_proj = operations.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias, dtype=dtype, device=device)
else:
self.add_q_proj = None
self.add_k_proj = None
self.add_v_proj = None
if not self.pre_only:
self.to_out = nn.ModuleList([])
self.to_out.append(operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device))
self.to_out.append(nn.Dropout(dropout))
else:
self.to_out = None
if self.context_pre_only is not None and not self.context_pre_only:
self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
else:
self.to_add_out = None
self.norm_added_q = None
self.norm_added_k = None
self.processor = processor
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**cross_attention_kwargs,
) -> torch.Tensor:
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
class CustomLiteLAProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections. add rms norm for query and key and apply RoPE"""
def __init__(self):
self.kernel_func = nn.ReLU(inplace=False)
self.eps = 1e-15
self.pad_val = 1.0
def apply_rotary_emb(
self,
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
x (`torch.Tensor`):
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
hidden_states_len = hidden_states.shape[1]
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
if encoder_hidden_states is not None:
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size = hidden_states.shape[0]
# `sample` projections.
dtype = hidden_states.dtype
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# `context` projections.
has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
# attention
if not attn.is_cross_attention:
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
else:
query = hidden_states
key = encoder_hidden_states
value = encoder_hidden_states
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
# RoPE需要 [B, H, S, D] 输入
# 此时 query是 [B, H, D, S], 需要转成 [B, H, S, D] 才能应用RoPE
query = query.permute(0, 1, 3, 2) # [B, H, S, D] (从 [B, H, D, S])
# Apply query and key normalization if needed
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if rotary_freqs_cis is not None:
query = self.apply_rotary_emb(query, rotary_freqs_cis)
if not attn.is_cross_attention:
key = self.apply_rotary_emb(key, rotary_freqs_cis)
elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
# 此时 query是 [B, H, S, D],需要还原成 [B, H, D, S]
query = query.permute(0, 1, 3, 2) # [B, H, D, S]
if attention_mask is not None:
# attention_mask: [B, S] -> [B, 1, S, 1]
attention_mask = attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S, 1]
query = query * attention_mask.permute(0, 1, 3, 2) # [B, H, S, D] * [B, 1, S, 1]
if not attn.is_cross_attention:
key = key * attention_mask # key: [B, h, S, D] 与 mask [B, 1, S, 1] 相乘
value = value * attention_mask.permute(0, 1, 3, 2) # 如果 value 是 [B, h, D, S]那么需调整mask以匹配S维度
if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S_enc, 1]
# 此时 key: [B, h, S_enc, D], value: [B, h, D, S_enc]
key = key * encoder_attention_mask # [B, h, S_enc, D] * [B, 1, S_enc, 1]
value = value * encoder_attention_mask.permute(0, 1, 3, 2) # [B, h, D, S_enc] * [B, 1, 1, S_enc]
query = self.kernel_func(query)
key = self.kernel_func(key)
query, key, value = query.float(), key.float(), value.float()
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
vk = torch.matmul(value, key)
hidden_states = torch.matmul(vk, query)
if hidden_states.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.float()
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
hidden_states = hidden_states.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)
hidden_states = hidden_states.to(dtype)
if encoder_hidden_states is not None:
encoder_hidden_states = encoder_hidden_states.to(dtype)
# Split the attention outputs.
if encoder_hidden_states is not None and not attn.is_cross_attention and has_encoder_hidden_state_proj:
hidden_states, encoder_hidden_states = (
hidden_states[:, : hidden_states_len],
hidden_states[:, hidden_states_len:],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if encoder_hidden_states is not None and not attn.context_pre_only and not attn.is_cross_attention and hasattr(attn, "to_add_out"):
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if encoder_hidden_states is not None and context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if torch.get_autocast_gpu_dtype() == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
if encoder_hidden_states is not None:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return hidden_states, encoder_hidden_states
class CustomerAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def apply_rotary_emb(
self,
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
x (`torch.Tensor`):
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
*args,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if rotary_freqs_cis is not None:
query = self.apply_rotary_emb(query, rotary_freqs_cis)
if not attn.is_cross_attention:
key = self.apply_rotary_emb(key, rotary_freqs_cis)
elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
# attention_mask: N x S1
# encoder_attention_mask: N x S2
# cross attention 整合attention_mask和encoder_attention_mask
combined_mask = attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
attention_mask = attention_mask[:, None, :, :].expand(-1, attn.heads, -1, -1).to(query.dtype)
elif not attn.is_cross_attention and attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def val2list(x: list or tuple or any, repeat_time=1) -> list: # type: ignore
"""Repeat `val` for `repeat_time` times and return the list or val if list/tuple."""
if isinstance(x, (list, tuple)):
return list(x)
return [x for _ in range(repeat_time)]
def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple: # type: ignore
"""Return tuple with min_len by repeating element at idx_repeat."""
# convert to list first
x = val2list(x)
# repeat elements if necessary
if len(x) > 0:
x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
return tuple(x)
def t2i_modulate(x, shift, scale):
return x * (1 + scale) + shift
def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
if isinstance(kernel_size, tuple):
return tuple([get_same_padding(ks) for ks in kernel_size])
else:
assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
return kernel_size // 2
class ConvLayer(nn.Module):
def __init__(
self,
in_dim: int,
out_dim: int,
kernel_size=3,
stride=1,
dilation=1,
groups=1,
padding: Union[int, None] = None,
use_bias=False,
norm=None,
act=None,
dtype=None, device=None, operations=None
):
super().__init__()
if padding is None:
padding = get_same_padding(kernel_size)
padding *= dilation
self.in_dim = in_dim
self.out_dim = out_dim
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.groups = groups
self.padding = padding
self.use_bias = use_bias
self.conv = operations.Conv1d(
in_dim,
out_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=use_bias,
device=device,
dtype=dtype
)
if norm is not None:
self.norm = operations.RMSNorm(out_dim, elementwise_affine=False, dtype=dtype, device=device)
else:
self.norm = None
if act is not None:
self.act = nn.SiLU(inplace=True)
else:
self.act = None
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv(x)
if self.norm:
x = self.norm(x)
if self.act:
x = self.act(x)
return x
class GLUMBConv(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int,
out_feature=None,
kernel_size=3,
stride=1,
padding: Union[int, None] = None,
use_bias=False,
norm=(None, None, None),
act=("silu", "silu", None),
dilation=1,
dtype=None, device=None, operations=None
):
out_feature = out_feature or in_features
super().__init__()
use_bias = val2tuple(use_bias, 3)
norm = val2tuple(norm, 3)
act = val2tuple(act, 3)
self.glu_act = nn.SiLU(inplace=False)
self.inverted_conv = ConvLayer(
in_features,
hidden_features * 2,
1,
use_bias=use_bias[0],
norm=norm[0],
act=act[0],
dtype=dtype,
device=device,
operations=operations,
)
self.depth_conv = ConvLayer(
hidden_features * 2,
hidden_features * 2,
kernel_size,
stride=stride,
groups=hidden_features * 2,
padding=padding,
use_bias=use_bias[1],
norm=norm[1],
act=None,
dilation=dilation,
dtype=dtype,
device=device,
operations=operations,
)
self.point_conv = ConvLayer(
hidden_features,
out_feature,
1,
use_bias=use_bias[2],
norm=norm[2],
act=act[2],
dtype=dtype,
device=device,
operations=operations,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.transpose(1, 2)
x = self.inverted_conv(x)
x = self.depth_conv(x)
x, gate = torch.chunk(x, 2, dim=1)
gate = self.glu_act(gate)
x = x * gate
x = self.point_conv(x)
x = x.transpose(1, 2)
return x
class LinearTransformerBlock(nn.Module):
"""
A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.
"""
def __init__(
self,
dim,
num_attention_heads,
attention_head_dim,
use_adaln_single=True,
cross_attention_dim=None,
added_kv_proj_dim=None,
context_pre_only=False,
mlp_ratio=4.0,
add_cross_attention=False,
add_cross_attention_dim=None,
qk_norm=None,
dtype=None, device=None, operations=None
):
super().__init__()
self.norm1 = operations.RMSNorm(dim, elementwise_affine=False, eps=1e-6)
self.attn = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
added_kv_proj_dim=added_kv_proj_dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
qk_norm=qk_norm,
processor=CustomLiteLAProcessor2_0(),
dtype=dtype,
device=device,
operations=operations,
)
self.add_cross_attention = add_cross_attention
self.context_pre_only = context_pre_only
if add_cross_attention and add_cross_attention_dim is not None:
self.cross_attn = Attention(
query_dim=dim,
cross_attention_dim=add_cross_attention_dim,
added_kv_proj_dim=add_cross_attention_dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=context_pre_only,
bias=True,
qk_norm=qk_norm,
processor=CustomerAttnProcessor2_0(),
dtype=dtype,
device=device,
operations=operations,
)
self.norm2 = operations.RMSNorm(dim, 1e-06, elementwise_affine=False)
self.ff = GLUMBConv(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
use_bias=(True, True, False),
norm=(None, None, None),
act=("silu", "silu", None),
dtype=dtype,
device=device,
operations=operations,
)
self.use_adaln_single = use_adaln_single
if use_adaln_single:
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, dtype=dtype, device=device))
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: torch.FloatTensor = None,
encoder_attention_mask: torch.FloatTensor = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
temb: torch.FloatTensor = None,
):
N = hidden_states.shape[0]
# step 1: AdaLN single
if self.use_adaln_single:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
comfy.model_management.cast_to(self.scale_shift_table[None], dtype=temb.dtype, device=temb.device) + temb.reshape(N, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
if self.use_adaln_single:
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
# step 2: attention
if not self.add_cross_attention:
attn_output, encoder_hidden_states = self.attn(
hidden_states=norm_hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
)
else:
attn_output, _ = self.attn(
hidden_states=norm_hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=None,
encoder_attention_mask=None,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=None,
)
if self.use_adaln_single:
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if self.add_cross_attention:
attn_output = self.cross_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
)
hidden_states = attn_output + hidden_states
# step 3: add norm
norm_hidden_states = self.norm2(hidden_states)
if self.use_adaln_single:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
# step 4: feed forward
ff_output = self.ff(norm_hidden_states)
if self.use_adaln_single:
ff_output = gate_mlp * ff_output
hidden_states = hidden_states + ff_output
return hidden_states

File diff suppressed because it is too large Load Diff

381
comfy/ldm/ace/model.py Normal file
View File

@ -0,0 +1,381 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/ace_step_transformer.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, List, Union
import torch
from torch import nn
import comfy.model_management
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from .attention import LinearTransformerBlock, t2i_modulate
from .lyric_encoder import ConformerEncoder as LyricEncoder
def cross_norm(hidden_states, controlnet_input):
# input N x T x c
mean_hidden_states, std_hidden_states = hidden_states.mean(dim=(1,2), keepdim=True), hidden_states.std(dim=(1,2), keepdim=True)
mean_controlnet_input, std_controlnet_input = controlnet_input.mean(dim=(1,2), keepdim=True), controlnet_input.std(dim=(1,2), keepdim=True)
controlnet_input = (controlnet_input - mean_controlnet_input) * (std_hidden_states / (std_controlnet_input + 1e-12)) + mean_hidden_states
return controlnet_input
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, dtype=None, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=device).float() / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
class T2IFinalLayer(nn.Module):
"""
The final layer of Sana.
"""
def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True, dtype=dtype, device=device)
self.scale_shift_table = nn.Parameter(torch.empty(2, hidden_size, dtype=dtype, device=device))
self.out_channels = out_channels
self.patch_size = patch_size
def unpatchfy(
self,
hidden_states: torch.Tensor,
width: int,
):
# 4 unpatchify
new_height, new_width = 1, hidden_states.size(1)
hidden_states = hidden_states.reshape(
shape=(hidden_states.shape[0], new_height, new_width, self.patch_size[0], self.patch_size[1], self.out_channels)
).contiguous()
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(hidden_states.shape[0], self.out_channels, new_height * self.patch_size[0], new_width * self.patch_size[1])
).contiguous()
if width > new_width:
output = torch.nn.functional.pad(output, (0, width - new_width, 0, 0), 'constant', 0)
elif width < new_width:
output = output[:, :, :, :width]
return output
def forward(self, x, t, output_length):
shift, scale = (comfy.model_management.cast_to(self.scale_shift_table[None], device=t.device, dtype=t.dtype) + t[:, None]).chunk(2, dim=1)
x = t2i_modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
# unpatchify
output = self.unpatchfy(x, output_length)
return output
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding"""
def __init__(
self,
height=16,
width=4096,
patch_size=(16, 1),
in_channels=8,
embed_dim=1152,
bias=True,
dtype=None, device=None, operations=None
):
super().__init__()
patch_size_h, patch_size_w = patch_size
self.early_conv_layers = nn.Sequential(
operations.Conv2d(in_channels, in_channels*256, kernel_size=patch_size, stride=patch_size, padding=0, bias=bias, dtype=dtype, device=device),
operations.GroupNorm(num_groups=32, num_channels=in_channels*256, eps=1e-6, affine=True, dtype=dtype, device=device),
operations.Conv2d(in_channels*256, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias, dtype=dtype, device=device)
)
self.patch_size = patch_size
self.height, self.width = height // patch_size_h, width // patch_size_w
self.base_size = self.width
def forward(self, latent):
# early convolutions, N x C x H x W -> N x 256 * sqrt(patch_size) x H/patch_size x W/patch_size
latent = self.early_conv_layers(latent)
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
return latent
class ACEStepTransformer2DModel(nn.Module):
# _supports_gradient_checkpointing = True
def __init__(
self,
in_channels: Optional[int] = 8,
num_layers: int = 28,
inner_dim: int = 1536,
attention_head_dim: int = 64,
num_attention_heads: int = 24,
mlp_ratio: float = 4.0,
out_channels: int = 8,
max_position: int = 32768,
rope_theta: float = 1000000.0,
speaker_embedding_dim: int = 512,
text_embedding_dim: int = 768,
ssl_encoder_depths: List[int] = [9, 9],
ssl_names: List[str] = ["mert", "m-hubert"],
ssl_latent_dims: List[int] = [1024, 768],
lyric_encoder_vocab_size: int = 6681,
lyric_hidden_size: int = 1024,
patch_size: List[int] = [16, 1],
max_height: int = 16,
max_width: int = 4096,
audio_model=None,
dtype=None, device=None, operations=None
):
super().__init__()
self.dtype = dtype
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.inner_dim = inner_dim
self.out_channels = out_channels
self.max_position = max_position
self.patch_size = patch_size
self.rope_theta = rope_theta
self.rotary_emb = Qwen2RotaryEmbedding(
dim=self.attention_head_dim,
max_position_embeddings=self.max_position,
base=self.rope_theta,
dtype=dtype,
device=device,
)
# 2. Define input layers
self.in_channels = in_channels
self.num_layers = num_layers
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
LinearTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.num_attention_heads,
attention_head_dim=attention_head_dim,
mlp_ratio=mlp_ratio,
add_cross_attention=True,
add_cross_attention_dim=self.inner_dim,
dtype=dtype,
device=device,
operations=operations,
)
for i in range(self.num_layers)
]
)
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim, dtype=dtype, device=device, operations=operations)
self.t_block = nn.Sequential(nn.SiLU(), operations.Linear(self.inner_dim, 6 * self.inner_dim, bias=True, dtype=dtype, device=device))
# speaker
self.speaker_embedder = operations.Linear(speaker_embedding_dim, self.inner_dim, dtype=dtype, device=device)
# genre
self.genre_embedder = operations.Linear(text_embedding_dim, self.inner_dim, dtype=dtype, device=device)
# lyric
self.lyric_embs = operations.Embedding(lyric_encoder_vocab_size, lyric_hidden_size, dtype=dtype, device=device)
self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0, dtype=dtype, device=device, operations=operations)
self.lyric_proj = operations.Linear(lyric_hidden_size, self.inner_dim, dtype=dtype, device=device)
projector_dim = 2 * self.inner_dim
self.projectors = nn.ModuleList([
nn.Sequential(
operations.Linear(self.inner_dim, projector_dim, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(projector_dim, projector_dim, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(projector_dim, ssl_dim, dtype=dtype, device=device),
) for ssl_dim in ssl_latent_dims
])
self.proj_in = PatchEmbed(
height=max_height,
width=max_width,
patch_size=patch_size,
embed_dim=self.inner_dim,
bias=True,
dtype=dtype,
device=device,
operations=operations,
)
self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels, dtype=dtype, device=device, operations=operations)
def forward_lyric_encoder(
self,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
out_dtype=None,
):
# N x T x D
lyric_embs = self.lyric_embs(lyric_token_idx, out_dtype=out_dtype)
prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
return prompt_prenet_out
def encode(
self,
encoder_text_hidden_states: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.LongTensor] = None,
speaker_embeds: Optional[torch.FloatTensor] = None,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
):
bs = encoder_text_hidden_states.shape[0]
device = encoder_text_hidden_states.device
# speaker embedding
encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)
# genre embedding
encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)
# lyric
encoder_lyric_hidden_states = self.forward_lyric_encoder(
lyric_token_idx=lyric_token_idx,
lyric_mask=lyric_mask,
out_dtype=encoder_text_hidden_states.dtype,
)
encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1)
encoder_hidden_mask = None
if text_attention_mask is not None:
speaker_mask = torch.ones(bs, 1, device=device)
encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)
return encoder_hidden_states, encoder_hidden_mask
def decode(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_mask: torch.Tensor,
timestep: Optional[torch.Tensor],
output_length: int = 0,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
return_dict: bool = True,
):
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
temb = self.t_block(embedded_timestep)
hidden_states = self.proj_in(hidden_states)
# controlnet logic
if block_controlnet_hidden_states is not None:
control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
hidden_states = hidden_states + control_condi * controlnet_scale
# inner_hidden_states = []
rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])
for index_block, block in enumerate(self.transformer_blocks):
hidden_states = block(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_hidden_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
temb=temb,
)
output = self.final_layer(hidden_states, embedded_timestep, output_length)
return output
def forward(
self,
x,
timestep,
attention_mask=None,
context: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.LongTensor] = None,
speaker_embeds: Optional[torch.FloatTensor] = None,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
**kwargs
):
hidden_states = x
encoder_text_hidden_states = context
encoder_hidden_states, encoder_hidden_mask = self.encode(
encoder_text_hidden_states=encoder_text_hidden_states,
text_attention_mask=text_attention_mask,
speaker_embeds=speaker_embeds,
lyric_token_idx=lyric_token_idx,
lyric_mask=lyric_mask,
)
output_length = hidden_states.shape[-1]
output = self.decode(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_mask=encoder_hidden_mask,
timestep=timestep,
output_length=output_length,
block_controlnet_hidden_states=block_controlnet_hidden_states,
controlnet_scale=controlnet_scale,
)
return output

View File

@ -0,0 +1,644 @@
# Rewritten from diffusers
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Union
import comfy.model_management
import comfy.ops
ops = comfy.ops.disable_weight_init
class RMSNorm(ops.RMSNorm):
def __init__(self, dim, eps=1e-5, elementwise_affine=True, bias=False):
super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)
if elementwise_affine:
self.bias = nn.Parameter(torch.empty(dim)) if bias else None
def forward(self, x):
x = super().forward(x)
if self.elementwise_affine:
if self.bias is not None:
x = x + comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device)
return x
def get_normalization(norm_type, num_features, num_groups=32, eps=1e-5):
if norm_type == "batch_norm":
return nn.BatchNorm2d(num_features)
elif norm_type == "group_norm":
return ops.GroupNorm(num_groups, num_features)
elif norm_type == "layer_norm":
return ops.LayerNorm(num_features)
elif norm_type == "rms_norm":
return RMSNorm(num_features, eps=eps, elementwise_affine=True, bias=True)
else:
raise ValueError(f"Unknown normalization type: {norm_type}")
def get_activation(activation_type):
if activation_type == "relu":
return nn.ReLU()
elif activation_type == "relu6":
return nn.ReLU6()
elif activation_type == "silu":
return nn.SiLU()
elif activation_type == "leaky_relu":
return nn.LeakyReLU(0.2)
else:
raise ValueError(f"Unknown activation type: {activation_type}")
class ResBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
norm_type: str = "batch_norm",
act_fn: str = "relu6",
) -> None:
super().__init__()
self.norm_type = norm_type
self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity()
self.conv1 = ops.Conv2d(in_channels, in_channels, 3, 1, 1)
self.conv2 = ops.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
self.norm = get_normalization(norm_type, out_channels)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.conv1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.norm_type == "rms_norm":
# move channel to the last dimension so we apply RMSnorm across channel dimension
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
else:
hidden_states = self.norm(hidden_states)
return hidden_states + residual
class SanaMultiscaleAttentionProjection(nn.Module):
def __init__(
self,
in_channels: int,
num_attention_heads: int,
kernel_size: int,
) -> None:
super().__init__()
channels = 3 * in_channels
self.proj_in = ops.Conv2d(
channels,
channels,
kernel_size,
padding=kernel_size // 2,
groups=channels,
bias=False,
)
self.proj_out = ops.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.proj_in(hidden_states)
hidden_states = self.proj_out(hidden_states)
return hidden_states
class SanaMultiscaleLinearAttention(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_attention_heads: int = None,
attention_head_dim: int = 8,
mult: float = 1.0,
norm_type: str = "batch_norm",
kernel_sizes: tuple = (5,),
eps: float = 1e-15,
residual_connection: bool = False,
):
super().__init__()
self.eps = eps
self.attention_head_dim = attention_head_dim
self.norm_type = norm_type
self.residual_connection = residual_connection
num_attention_heads = (
int(in_channels // attention_head_dim * mult)
if num_attention_heads is None
else num_attention_heads
)
inner_dim = num_attention_heads * attention_head_dim
self.to_q = ops.Linear(in_channels, inner_dim, bias=False)
self.to_k = ops.Linear(in_channels, inner_dim, bias=False)
self.to_v = ops.Linear(in_channels, inner_dim, bias=False)
self.to_qkv_multiscale = nn.ModuleList()
for kernel_size in kernel_sizes:
self.to_qkv_multiscale.append(
SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
)
self.nonlinearity = nn.ReLU()
self.to_out = ops.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
self.norm_out = get_normalization(norm_type, out_channels)
def apply_linear_attention(self, query, key, value):
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)
scores = torch.matmul(value, key.transpose(-1, -2))
hidden_states = torch.matmul(scores, query)
hidden_states = hidden_states.to(dtype=torch.float32)
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
return hidden_states
def apply_quadratic_attention(self, query, key, value):
scores = torch.matmul(key.transpose(-1, -2), query)
scores = scores.to(dtype=torch.float32)
scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
hidden_states = torch.matmul(value, scores.to(value.dtype))
return hidden_states
def forward(self, hidden_states):
height, width = hidden_states.shape[-2:]
if height * width > self.attention_head_dim:
use_linear_attention = True
else:
use_linear_attention = False
residual = hidden_states
batch_size, _, height, width = list(hidden_states.size())
original_dtype = hidden_states.dtype
hidden_states = hidden_states.movedim(1, -1)
query = self.to_q(hidden_states)
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
hidden_states = torch.cat([query, key, value], dim=3)
hidden_states = hidden_states.movedim(-1, 1)
multi_scale_qkv = [hidden_states]
for block in self.to_qkv_multiscale:
multi_scale_qkv.append(block(hidden_states))
hidden_states = torch.cat(multi_scale_qkv, dim=1)
if use_linear_attention:
# for linear attention upcast hidden_states to float32
hidden_states = hidden_states.to(dtype=torch.float32)
hidden_states = hidden_states.reshape(batch_size, -1, 3 * self.attention_head_dim, height * width)
query, key, value = hidden_states.chunk(3, dim=2)
query = self.nonlinearity(query)
key = self.nonlinearity(key)
if use_linear_attention:
hidden_states = self.apply_linear_attention(query, key, value)
hidden_states = hidden_states.to(dtype=original_dtype)
else:
hidden_states = self.apply_quadratic_attention(query, key, value)
hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
hidden_states = self.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
if self.norm_type == "rms_norm":
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
else:
hidden_states = self.norm_out(hidden_states)
if self.residual_connection:
hidden_states = hidden_states + residual
return hidden_states
class EfficientViTBlock(nn.Module):
def __init__(
self,
in_channels: int,
mult: float = 1.0,
attention_head_dim: int = 32,
qkv_multiscales: tuple = (5,),
norm_type: str = "batch_norm",
) -> None:
super().__init__()
self.attn = SanaMultiscaleLinearAttention(
in_channels=in_channels,
out_channels=in_channels,
mult=mult,
attention_head_dim=attention_head_dim,
norm_type=norm_type,
kernel_sizes=qkv_multiscales,
residual_connection=True,
)
self.conv_out = GLUMBConv(
in_channels=in_channels,
out_channels=in_channels,
norm_type="rms_norm",
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.attn(x)
x = self.conv_out(x)
return x
class GLUMBConv(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
expand_ratio: float = 4,
norm_type: str = None,
residual_connection: bool = True,
) -> None:
super().__init__()
hidden_channels = int(expand_ratio * in_channels)
self.norm_type = norm_type
self.residual_connection = residual_connection
self.nonlinearity = nn.SiLU()
self.conv_inverted = ops.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
self.conv_depth = ops.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
self.conv_point = ops.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
self.norm = None
if norm_type == "rms_norm":
self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.residual_connection:
residual = hidden_states
hidden_states = self.conv_inverted(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv_depth(hidden_states)
hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
hidden_states = hidden_states * self.nonlinearity(gate)
hidden_states = self.conv_point(hidden_states)
if self.norm_type == "rms_norm":
# move channel to the last dimension so we apply RMSnorm across channel dimension
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
if self.residual_connection:
hidden_states = hidden_states + residual
return hidden_states
def get_block(
block_type: str,
in_channels: int,
out_channels: int,
attention_head_dim: int,
norm_type: str,
act_fn: str,
qkv_mutliscales: tuple = (),
):
if block_type == "ResBlock":
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
elif block_type == "EfficientViTBlock":
block = EfficientViTBlock(
in_channels,
attention_head_dim=attention_head_dim,
norm_type=norm_type,
qkv_multiscales=qkv_mutliscales
)
else:
raise ValueError(f"Block with {block_type=} is not supported.")
return block
class DCDownBlock2d(nn.Module):
def __init__(self, in_channels: int, out_channels: int, downsample: bool = False, shortcut: bool = True) -> None:
super().__init__()
self.downsample = downsample
self.factor = 2
self.stride = 1 if downsample else 2
self.group_size = in_channels * self.factor**2 // out_channels
self.shortcut = shortcut
out_ratio = self.factor**2
if downsample:
assert out_channels % out_ratio == 0
out_channels = out_channels // out_ratio
self.conv = ops.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=self.stride,
padding=1,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
x = self.conv(hidden_states)
if self.downsample:
x = F.pixel_unshuffle(x, self.factor)
if self.shortcut:
y = F.pixel_unshuffle(hidden_states, self.factor)
y = y.unflatten(1, (-1, self.group_size))
y = y.mean(dim=2)
hidden_states = x + y
else:
hidden_states = x
return hidden_states
class DCUpBlock2d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
interpolate: bool = False,
shortcut: bool = True,
interpolation_mode: str = "nearest",
) -> None:
super().__init__()
self.interpolate = interpolate
self.interpolation_mode = interpolation_mode
self.shortcut = shortcut
self.factor = 2
self.repeats = out_channels * self.factor**2 // in_channels
out_ratio = self.factor**2
if not interpolate:
out_channels = out_channels * out_ratio
self.conv = ops.Conv2d(in_channels, out_channels, 3, 1, 1)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.interpolate:
x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
x = self.conv(x)
else:
x = self.conv(hidden_states)
x = F.pixel_shuffle(x, self.factor)
if self.shortcut:
y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
y = F.pixel_shuffle(y, self.factor)
hidden_states = x + y
else:
hidden_states = x
return hidden_states
class Encoder(nn.Module):
def __init__(
self,
in_channels: int,
latent_channels: int,
attention_head_dim: int = 32,
block_type: str or tuple = "ResBlock",
block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024),
layers_per_block: tuple = (2, 2, 2, 2, 2, 2),
qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)),
downsample_block_type: str = "pixel_unshuffle",
out_shortcut: bool = True,
):
super().__init__()
num_blocks = len(block_out_channels)
if isinstance(block_type, str):
block_type = (block_type,) * num_blocks
if layers_per_block[0] > 0:
self.conv_in = ops.Conv2d(
in_channels,
block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
kernel_size=3,
stride=1,
padding=1,
)
else:
self.conv_in = DCDownBlock2d(
in_channels=in_channels,
out_channels=block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
downsample=downsample_block_type == "pixel_unshuffle",
shortcut=False,
)
down_blocks = []
for i, (out_channel, num_layers) in enumerate(zip(block_out_channels, layers_per_block)):
down_block_list = []
for _ in range(num_layers):
block = get_block(
block_type[i],
out_channel,
out_channel,
attention_head_dim=attention_head_dim,
norm_type="rms_norm",
act_fn="silu",
qkv_mutliscales=qkv_multiscales[i],
)
down_block_list.append(block)
if i < num_blocks - 1 and num_layers > 0:
downsample_block = DCDownBlock2d(
in_channels=out_channel,
out_channels=block_out_channels[i + 1],
downsample=downsample_block_type == "pixel_unshuffle",
shortcut=True,
)
down_block_list.append(downsample_block)
down_blocks.append(nn.Sequential(*down_block_list))
self.down_blocks = nn.ModuleList(down_blocks)
self.conv_out = ops.Conv2d(block_out_channels[-1], latent_channels, 3, 1, 1)
self.out_shortcut = out_shortcut
if out_shortcut:
self.out_shortcut_average_group_size = block_out_channels[-1] // latent_channels
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
if self.out_shortcut:
x = hidden_states.unflatten(1, (-1, self.out_shortcut_average_group_size))
x = x.mean(dim=2)
hidden_states = self.conv_out(hidden_states) + x
else:
hidden_states = self.conv_out(hidden_states)
return hidden_states
class Decoder(nn.Module):
def __init__(
self,
in_channels: int,
latent_channels: int,
attention_head_dim: int = 32,
block_type: str or tuple = "ResBlock",
block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024),
layers_per_block: tuple = (2, 2, 2, 2, 2, 2),
qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)),
norm_type: str or tuple = "rms_norm",
act_fn: str or tuple = "silu",
upsample_block_type: str = "pixel_shuffle",
in_shortcut: bool = True,
):
super().__init__()
num_blocks = len(block_out_channels)
if isinstance(block_type, str):
block_type = (block_type,) * num_blocks
if isinstance(norm_type, str):
norm_type = (norm_type,) * num_blocks
if isinstance(act_fn, str):
act_fn = (act_fn,) * num_blocks
self.conv_in = ops.Conv2d(latent_channels, block_out_channels[-1], 3, 1, 1)
self.in_shortcut = in_shortcut
if in_shortcut:
self.in_shortcut_repeats = block_out_channels[-1] // latent_channels
up_blocks = []
for i, (out_channel, num_layers) in reversed(list(enumerate(zip(block_out_channels, layers_per_block)))):
up_block_list = []
if i < num_blocks - 1 and num_layers > 0:
upsample_block = DCUpBlock2d(
block_out_channels[i + 1],
out_channel,
interpolate=upsample_block_type == "interpolate",
shortcut=True,
)
up_block_list.append(upsample_block)
for _ in range(num_layers):
block = get_block(
block_type[i],
out_channel,
out_channel,
attention_head_dim=attention_head_dim,
norm_type=norm_type[i],
act_fn=act_fn[i],
qkv_mutliscales=qkv_multiscales[i],
)
up_block_list.append(block)
up_blocks.insert(0, nn.Sequential(*up_block_list))
self.up_blocks = nn.ModuleList(up_blocks)
channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
self.conv_act = nn.ReLU()
self.conv_out = None
if layers_per_block[0] > 0:
self.conv_out = ops.Conv2d(channels, in_channels, 3, 1, 1)
else:
self.conv_out = DCUpBlock2d(
channels, in_channels, interpolate=upsample_block_type == "interpolate", shortcut=False
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.in_shortcut:
x = hidden_states.repeat_interleave(
self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
)
hidden_states = self.conv_in(hidden_states) + x
else:
hidden_states = self.conv_in(hidden_states)
for up_block in reversed(self.up_blocks):
hidden_states = up_block(hidden_states)
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class AutoencoderDC(nn.Module):
def __init__(
self,
in_channels: int = 2,
latent_channels: int = 8,
attention_head_dim: int = 32,
encoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"],
decoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"],
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024),
encoder_layers_per_block: Tuple[int] = (2, 2, 3, 3),
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3),
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)),
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)),
upsample_block_type: str = "interpolate",
downsample_block_type: str = "Conv",
decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
decoder_act_fns: Union[str, Tuple[str]] = "silu",
scaling_factor: float = 0.41407,
) -> None:
super().__init__()
self.encoder = Encoder(
in_channels=in_channels,
latent_channels=latent_channels,
attention_head_dim=attention_head_dim,
block_type=encoder_block_types,
block_out_channels=encoder_block_out_channels,
layers_per_block=encoder_layers_per_block,
qkv_multiscales=encoder_qkv_multiscales,
downsample_block_type=downsample_block_type,
)
self.decoder = Decoder(
in_channels=in_channels,
latent_channels=latent_channels,
attention_head_dim=attention_head_dim,
block_type=decoder_block_types,
block_out_channels=decoder_block_out_channels,
layers_per_block=decoder_layers_per_block,
qkv_multiscales=decoder_qkv_multiscales,
norm_type=decoder_norm_types,
act_fn=decoder_act_fns,
upsample_block_type=upsample_block_type,
)
self.scaling_factor = scaling_factor
self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
def encode(self, x: torch.Tensor) -> torch.Tensor:
"""Internal encoding function."""
encoded = self.encoder(x)
return encoded * self.scaling_factor
def decode(self, z: torch.Tensor) -> torch.Tensor:
# Scale the latents back
z = z / self.scaling_factor
decoded = self.decoder(z)
return decoded
def forward(self, x: torch.Tensor) -> torch.Tensor:
z = self.encode(x)
return self.decode(z)

View File

@ -0,0 +1,104 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_dcae_pipeline.py
import torch
from .autoencoder_dc import AutoencoderDC
import torchaudio
import torchvision.transforms as transforms
from .music_vocoder import ADaMoSHiFiGANV1
class MusicDCAE(torch.nn.Module):
def __init__(self, source_sample_rate=None, dcae_config={}, vocoder_config={}):
super(MusicDCAE, self).__init__()
self.dcae = AutoencoderDC(**dcae_config)
self.vocoder = ADaMoSHiFiGANV1(**vocoder_config)
if source_sample_rate is None:
self.source_sample_rate = 48000
else:
self.source_sample_rate = source_sample_rate
# self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
self.transform = transforms.Compose([
transforms.Normalize(0.5, 0.5),
])
self.min_mel_value = -11.0
self.max_mel_value = 3.0
self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
self.mel_chunk_size = 1024
self.time_dimention_multiple = 8
self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
self.scale_factor = 0.1786
self.shift_factor = -1.9091
def load_audio(self, audio_path):
audio, sr = torchaudio.load(audio_path)
return audio, sr
def forward_mel(self, audios):
mels = []
for i in range(len(audios)):
image = self.vocoder.mel_transform(audios[i])
mels.append(image)
mels = torch.stack(mels)
return mels
@torch.no_grad()
def encode(self, audios, audio_lengths=None, sr=None):
if audio_lengths is None:
audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
audio_lengths = audio_lengths.to(audios.device)
if sr is None:
sr = self.source_sample_rate
if sr != 44100:
audios = torchaudio.functional.resample(audios, sr, 44100)
max_audio_len = audios.shape[-1]
if max_audio_len % (8 * 512) != 0:
audios = torch.nn.functional.pad(audios, (0, 8 * 512 - max_audio_len % (8 * 512)))
mels = self.forward_mel(audios)
mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
mels = self.transform(mels)
latents = []
for mel in mels:
latent = self.dcae.encoder(mel.unsqueeze(0))
latents.append(latent)
latents = torch.cat(latents, dim=0)
# latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
latents = (latents - self.shift_factor) * self.scale_factor
return latents
# return latents, latent_lengths
@torch.no_grad()
def decode(self, latents, audio_lengths=None, sr=None):
latents = latents / self.scale_factor + self.shift_factor
pred_wavs = []
for latent in latents:
mels = self.dcae.decoder(latent.unsqueeze(0))
mels = mels * 0.5 + 0.5
mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
wav = self.vocoder.decode(mels[0]).squeeze(1)
if sr is not None:
# resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
wav = torchaudio.functional.resample(wav, 44100, sr)
# wav = resampler(wav)
else:
sr = 44100
pred_wavs.append(wav)
if audio_lengths is not None:
pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
return torch.stack(pred_wavs)
# return sr, pred_wavs
def forward(self, audios, audio_lengths=None, sr=None):
latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
sr, pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
return sr, pred_wavs, latents, latent_lengths

View File

@ -0,0 +1,108 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_log_mel.py
import torch
import torch.nn as nn
from torch import Tensor
from torchaudio.transforms import MelScale
import comfy.model_management
class LinearSpectrogram(nn.Module):
def __init__(
self,
n_fft=2048,
win_length=2048,
hop_length=512,
center=False,
mode="pow2_sqrt",
):
super().__init__()
self.n_fft = n_fft
self.win_length = win_length
self.hop_length = hop_length
self.center = center
self.mode = mode
self.register_buffer("window", torch.hann_window(win_length))
def forward(self, y: Tensor) -> Tensor:
if y.ndim == 3:
y = y.squeeze(1)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(
(self.win_length - self.hop_length) // 2,
(self.win_length - self.hop_length + 1) // 2,
),
mode="reflect",
).squeeze(1)
dtype = y.dtype
spec = torch.stft(
y.float(),
self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=comfy.model_management.cast_to(self.window, dtype=torch.float32, device=y.device),
center=self.center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
spec = torch.view_as_real(spec)
if self.mode == "pow2_sqrt":
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
spec = spec.to(dtype)
return spec
class LogMelSpectrogram(nn.Module):
def __init__(
self,
sample_rate=44100,
n_fft=2048,
win_length=2048,
hop_length=512,
n_mels=128,
center=False,
f_min=0.0,
f_max=None,
):
super().__init__()
self.sample_rate = sample_rate
self.n_fft = n_fft
self.win_length = win_length
self.hop_length = hop_length
self.center = center
self.n_mels = n_mels
self.f_min = f_min
self.f_max = f_max or sample_rate // 2
self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
self.mel_scale = MelScale(
self.n_mels,
self.sample_rate,
self.f_min,
self.f_max,
self.n_fft // 2 + 1,
"slaney",
"slaney",
)
def compress(self, x: Tensor) -> Tensor:
return torch.log(torch.clamp(x, min=1e-5))
def decompress(self, x: Tensor) -> Tensor:
return torch.exp(x)
def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
linear = self.spectrogram(x)
x = self.mel_scale(linear)
x = self.compress(x)
# print(x.shape)
if return_linear:
return x, self.compress(linear)
return x

View File

@ -0,0 +1,542 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_vocoder.py
import torch
from torch import nn
from functools import partial
from math import prod
from typing import Callable, Tuple, List
import numpy as np
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
# from diffusers.models.modeling_utils import ModelMixin
# from diffusers.loaders import FromOriginalModelMixin
# from diffusers.configuration_utils import ConfigMixin, register_to_config
from .music_log_mel import LogMelSpectrogram
import comfy.model_management
import comfy.ops
ops = comfy.ops.disable_weight_init
def drop_path(
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
""" # noqa: E501
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (
x.ndim - 1
) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
def extra_repr(self):
return f"drop_prob={round(self.drop_prob,3):0.3f}"
class LayerNorm(nn.Module):
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
with shape (batch_size, channels, height, width).
""" # noqa: E501
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError
self.normalized_shape = (normalized_shape,)
def forward(self, x):
if self.data_format == "channels_last":
return F.layer_norm(
x, self.normalized_shape, comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device), comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device), self.eps
)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = comfy.model_management.cast_to(self.weight[:, None], dtype=x.dtype, device=x.device) * x + comfy.model_management.cast_to(self.bias[:, None], dtype=x.dtype, device=x.device)
return x
class ConvNeXtBlock(nn.Module):
r"""ConvNeXt Block. There are two equivalent implementations:
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
We use (2) as we find it slightly faster in PyTorch
Args:
dim (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
kernel_size (int): Kernel size for depthwise conv. Default: 7.
dilation (int): Dilation for depthwise conv. Default: 1.
""" # noqa: E501
def __init__(
self,
dim: int,
drop_path: float = 0.0,
layer_scale_init_value: float = 1e-6,
mlp_ratio: float = 4.0,
kernel_size: int = 7,
dilation: int = 1,
):
super().__init__()
self.dwconv = ops.Conv1d(
dim,
dim,
kernel_size=kernel_size,
padding=int(dilation * (kernel_size - 1) / 2),
groups=dim,
) # depthwise conv
self.norm = LayerNorm(dim, eps=1e-6)
self.pwconv1 = ops.Linear(
dim, int(mlp_ratio * dim)
) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = ops.Linear(int(mlp_ratio * dim), dim)
self.gamma = (
nn.Parameter(torch.empty((dim)), requires_grad=False)
if layer_scale_init_value > 0
else None
)
self.drop_path = DropPath(
drop_path) if drop_path > 0.0 else nn.Identity()
def forward(self, x, apply_residual: bool = True):
input = x
x = self.dwconv(x)
x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) * x
x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
x = self.drop_path(x)
if apply_residual:
x = input + x
return x
class ParallelConvNeXtBlock(nn.Module):
def __init__(self, kernel_sizes: List[int], *args, **kwargs):
super().__init__()
self.blocks = nn.ModuleList(
[
ConvNeXtBlock(kernel_size=kernel_size, *args, **kwargs)
for kernel_size in kernel_sizes
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.stack(
[block(x, apply_residual=False) for block in self.blocks] + [x],
dim=1,
).sum(dim=1)
class ConvNeXtEncoder(nn.Module):
def __init__(
self,
input_channels=3,
depths=[3, 3, 9, 3],
dims=[96, 192, 384, 768],
drop_path_rate=0.0,
layer_scale_init_value=1e-6,
kernel_sizes: Tuple[int] = (7,),
):
super().__init__()
assert len(depths) == len(dims)
self.channel_layers = nn.ModuleList()
stem = nn.Sequential(
ops.Conv1d(
input_channels,
dims[0],
kernel_size=7,
padding=3,
padding_mode="replicate",
),
LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
)
self.channel_layers.append(stem)
for i in range(len(depths) - 1):
mid_layer = nn.Sequential(
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
ops.Conv1d(dims[i], dims[i + 1], kernel_size=1),
)
self.channel_layers.append(mid_layer)
block_fn = (
partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
if len(kernel_sizes) == 1
else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
)
self.stages = nn.ModuleList()
drop_path_rates = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
]
cur = 0
for i in range(len(depths)):
stage = nn.Sequential(
*[
block_fn(
dim=dims[i],
drop_path=drop_path_rates[cur + j],
layer_scale_init_value=layer_scale_init_value,
)
for j in range(depths[i])
]
)
self.stages.append(stage)
cur += depths[i]
self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
for channel_layer, stage in zip(self.channel_layers, self.stages):
x = channel_layer(x)
x = stage(x)
return self.norm(x)
def get_padding(kernel_size, dilation=1):
return (kernel_size * dilation - dilation) // 2
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super().__init__()
self.convs1 = nn.ModuleList(
[
weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs2 = nn.ModuleList(
[
weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
]
)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.silu(x)
xt = c1(xt)
xt = F.silu(xt)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for conv in self.convs1:
remove_weight_norm(conv)
for conv in self.convs2:
remove_weight_norm(conv)
class HiFiGANGenerator(nn.Module):
def __init__(
self,
*,
hop_length: int = 512,
upsample_rates: Tuple[int] = (8, 8, 2, 2, 2),
upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2),
resblock_kernel_sizes: Tuple[int] = (3, 7, 11),
resblock_dilation_sizes: Tuple[Tuple[int]] = (
(1, 3, 5), (1, 3, 5), (1, 3, 5)),
num_mels: int = 128,
upsample_initial_channel: int = 512,
use_template: bool = True,
pre_conv_kernel_size: int = 7,
post_conv_kernel_size: int = 7,
post_activation: Callable = partial(nn.SiLU, inplace=True),
):
super().__init__()
assert (
prod(upsample_rates) == hop_length
), f"hop_length must be {prod(upsample_rates)}"
self.conv_pre = weight_norm(
ops.Conv1d(
num_mels,
upsample_initial_channel,
pre_conv_kernel_size,
1,
padding=get_padding(pre_conv_kernel_size),
)
)
self.num_upsamples = len(upsample_rates)
self.num_kernels = len(resblock_kernel_sizes)
self.noise_convs = nn.ModuleList()
self.use_template = use_template
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
c_cur = upsample_initial_channel // (2 ** (i + 1))
self.ups.append(
weight_norm(
ops.ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
if not use_template:
continue
if i + 1 < len(upsample_rates):
stride_f0 = np.prod(upsample_rates[i + 1:])
self.noise_convs.append(
ops.Conv1d(
1,
c_cur,
kernel_size=stride_f0 * 2,
stride=stride_f0,
padding=stride_f0 // 2,
)
)
else:
self.noise_convs.append(ops.Conv1d(1, c_cur, kernel_size=1))
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
self.resblocks.append(ResBlock1(ch, k, d))
self.activation_post = post_activation()
self.conv_post = weight_norm(
ops.Conv1d(
ch,
1,
post_conv_kernel_size,
1,
padding=get_padding(post_conv_kernel_size),
)
)
def forward(self, x, template=None):
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.silu(x, inplace=True)
x = self.ups[i](x)
if self.use_template:
x = x + self.noise_convs[i](template)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = self.activation_post(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
for up in self.ups:
remove_weight_norm(up)
for block in self.resblocks:
block.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
class ADaMoSHiFiGANV1(nn.Module):
def __init__(
self,
input_channels: int = 128,
depths: List[int] = [3, 3, 9, 3],
dims: List[int] = [128, 256, 384, 512],
drop_path_rate: float = 0.0,
kernel_sizes: Tuple[int] = (7,),
upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2),
upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4),
resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13),
resblock_dilation_sizes: Tuple[Tuple[int]] = (
(1, 3, 5), (1, 3, 5), (1, 3, 5), (1, 3, 5)),
num_mels: int = 512,
upsample_initial_channel: int = 1024,
use_template: bool = False,
pre_conv_kernel_size: int = 13,
post_conv_kernel_size: int = 13,
sampling_rate: int = 44100,
n_fft: int = 2048,
win_length: int = 2048,
hop_length: int = 512,
f_min: int = 40,
f_max: int = 16000,
n_mels: int = 128,
):
super().__init__()
self.backbone = ConvNeXtEncoder(
input_channels=input_channels,
depths=depths,
dims=dims,
drop_path_rate=drop_path_rate,
kernel_sizes=kernel_sizes,
)
self.head = HiFiGANGenerator(
hop_length=hop_length,
upsample_rates=upsample_rates,
upsample_kernel_sizes=upsample_kernel_sizes,
resblock_kernel_sizes=resblock_kernel_sizes,
resblock_dilation_sizes=resblock_dilation_sizes,
num_mels=num_mels,
upsample_initial_channel=upsample_initial_channel,
use_template=use_template,
pre_conv_kernel_size=pre_conv_kernel_size,
post_conv_kernel_size=post_conv_kernel_size,
)
self.sampling_rate = sampling_rate
self.mel_transform = LogMelSpectrogram(
sample_rate=sampling_rate,
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
f_min=f_min,
f_max=f_max,
n_mels=n_mels,
)
self.eval()
@torch.no_grad()
def decode(self, mel):
y = self.backbone(mel)
y = self.head(y)
return y
@torch.no_grad()
def encode(self, x):
return self.mel_transform(x)
def forward(self, mel):
y = self.backbone(mel)
y = self.head(y)
return y

View File

@ -39,6 +39,7 @@ import comfy.ldm.wan.model
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.ace.model
import comfy.model_management
import comfy.patcher_extension
@ -1121,3 +1122,21 @@ class Chroma(Flux):
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class ACEStep(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.model.ACEStepTransformer2DModel)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
noise = kwargs.get("noise", None)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
if cross_attn is not None:
out['lyric_token_idx'] = comfy.conds.CONDRegular(conditioning_lyrics)
out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
return out

View File

@ -226,6 +226,31 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
return dit_config
if '{}genre_embedder.weight'.format(key_prefix) in state_dict_keys: #ACE-Step model
dit_config = {}
dit_config["audio_model"] = "ace"
dit_config["attention_head_dim"] = 128
dit_config["in_channels"] = 8
dit_config["inner_dim"] = 2560
dit_config["max_height"] = 16
dit_config["max_position"] = 32768
dit_config["max_width"] = 32768
dit_config["mlp_ratio"] = 2.5
dit_config["num_attention_heads"] = 20
dit_config["num_layers"] = 24
dit_config["out_channels"] = 8
dit_config["patch_size"] = [16, 1]
dit_config["rope_theta"] = 1000000.0
dit_config["speaker_embedding_dim"] = 512
dit_config["text_embedding_dim"] = 768
dit_config["ssl_encoder_depths"] = [8, 8]
dit_config["ssl_latent_dims"] = [1024, 768]
dit_config["ssl_names"] = ["mert", "m-hubert"]
dit_config["lyric_encoder_vocab_size"] = 6693
dit_config["lyric_hidden_size"] = 1024
return dit_config
if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
patch_size = 2
dit_config = {}

View File

@ -15,6 +15,7 @@ import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import yaml
import math
@ -42,6 +43,7 @@ import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.text_encoders.ace
import comfy.model_patcher
import comfy.lora
@ -437,6 +439,19 @@ class VAE:
ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
self.memory_used_encode = lambda shape, dtype: (shape[2] * 300) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 72000) * model_management.dtype_size(dtype)
self.latent_channels = 8
self.output_channels = 2
# self.upscale_ratio = 2048
# self.downscale_ratio = 2048
self.latent_dim = 2
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = True
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@ -715,6 +730,7 @@ class CLIPType(Enum):
WAN = 13
HIDREAM = 14
CHROMA = 15
ACE = 16
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@ -840,6 +856,11 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
elif te_model == TEModel.T5_BASE:
if clip_type == CLIPType.ACE or "spiece_model" in clip_data[0]:
clip_target.clip = comfy.text_encoders.ace.AceT5Model
clip_target.tokenizer = comfy.text_encoders.ace.AceT5Tokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
else:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
elif te_model == TEModel.GEMMA_2_2B:

View File

@ -17,6 +17,7 @@ import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.ace
from . import supported_models_base
from . import latent_formats
@ -1100,6 +1101,34 @@ class Chroma(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma]
class ACEStep(supported_models_base.BASE):
unet_config = {
"audio_model": "ace",
}
unet_extra_config = {
}
sampling_settings = {
"shift": 3.0,
}
latent_format = comfy.latent_formats.ACEAudio
memory_usage_factor = 0.5
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.ACEStep(self, device=device)
return out
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
models += [SVD_img2vid]

145
comfy/text_encoders/ace.py Normal file
View File

@ -0,0 +1,145 @@
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.t5
import os
import re
import torch
import logging
from tokenizers import Tokenizer
from .ace_text_cleaners import multilingual_cleaners
SUPPORT_LANGUAGES = {
"en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,
"pt": 286, "pl": 294, "tr": 295, "ru": 267, "cs": 293,
"nl": 297, "ar": 5022, "zh": 5023, "ja": 5412, "hu": 5753,
"ko": 6152, "hi": 6680
}
structure_pattern = re.compile(r"\[.*?\]")
DEFAULT_VOCAB_FILE = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
class VoiceBpeTokenizer:
def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
self.tokenizer = None
if vocab_file is not None:
self.tokenizer = Tokenizer.from_file(vocab_file)
def preprocess_text(self, txt, lang):
txt = multilingual_cleaners(txt, lang)
return txt
def encode(self, txt, lang='en'):
# lang = lang.split("-")[0] # remove the region
# self.check_input_length(txt, lang)
txt = self.preprocess_text(txt, lang)
lang = "zh-cn" if lang == "zh" else lang
txt = f"[{lang}]{txt}"
txt = txt.replace(" ", "[SPACE]")
return self.tokenizer.encode(txt).ids
def get_lang(self, line):
if line.startswith("[") and line[3:4] == ']':
lang = line[1:3].lower()
if lang in SUPPORT_LANGUAGES:
return lang, line[4:]
return "en", line
def __call__(self, string):
lines = string.split("\n")
lyric_token_idx = [261]
for line in lines:
line = line.strip()
if not line:
lyric_token_idx += [2]
continue
lang, line = self.get_lang(line)
if lang not in SUPPORT_LANGUAGES:
lang = "en"
if "zh" in lang:
lang = "zh"
if "spa" in lang:
lang = "es"
try:
if structure_pattern.match(line):
token_idx = self.encode(line, "en")
else:
token_idx = self.encode(line, lang)
lyric_token_idx = lyric_token_idx + token_idx + [2]
except Exception as e:
logging.warning("tokenize error {} for line {} major_language {}".format(e, line, lang))
return {"input_ids": lyric_token_idx}
@staticmethod
def from_pretrained(path, **kwargs):
return VoiceBpeTokenizer(path, **kwargs)
def get_vocab(self):
return {}
class UMT5BaseModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "umt5_config_base.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=False, model_options=model_options)
class UMT5BaseTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=768, embedding_key='umt5base', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=0, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class LyricsTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='lyrics', tokenizer_class=VoiceBpeTokenizer, has_start_token=True, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=2, has_end_token=False, tokenizer_data=tokenizer_data)
class AceT5Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.voicebpe = LyricsTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.umt5base = UMT5BaseTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out["lyrics"] = self.voicebpe.tokenize_with_weights(kwargs.get("lyrics", ""), return_word_ids, **kwargs)
out["umt5base"] = self.umt5base.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
def untokenize(self, token_weight_pair):
return self.umt5base.untokenize(token_weight_pair)
def state_dict(self):
return self.umt5base.state_dict()
class AceT5Model(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__()
self.umt5base = UMT5BaseModel(device=device, dtype=dtype, model_options=model_options)
self.dtypes = set()
if dtype is not None:
self.dtypes.add(dtype)
def set_clip_options(self, options):
self.umt5base.set_clip_options(options)
def reset_clip_options(self):
self.umt5base.reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_umt5base = token_weight_pairs["umt5base"]
token_weight_pairs_lyrics = token_weight_pairs["lyrics"]
t5_out, t5_pooled = self.umt5base.encode_token_weights(token_weight_pairs_umt5base)
lyrics_embeds = torch.tensor(list(map(lambda a: a[0], token_weight_pairs_lyrics[0]))).unsqueeze(0)
return t5_out, None, {"conditioning_lyrics": lyrics_embeds}
def load_sd(self, sd):
return self.umt5base.load_sd(sd)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,270 @@
# basic text cleaners for the ACE step model
# I didn't copy the ones from the reference code because I didn't want to deal with the dependencies
# TODO: more languages than english?
import re
def number_to_text(num, ordinal=False):
"""
Convert a number (int or float) to its text representation.
Args:
num: The number to convert
Returns:
str: Text representation of the number
"""
if not isinstance(num, (int, float)):
return "Input must be a number"
# Handle special case of zero
if num == 0:
return "zero"
# Handle negative numbers
negative = num < 0
num = abs(num)
# Handle floats
if isinstance(num, float):
# Split into integer and decimal parts
int_part = int(num)
# Convert both parts
int_text = _int_to_text(int_part)
# Handle decimal part (convert to string and remove '0.')
decimal_str = str(num).split('.')[1]
decimal_text = " point " + " ".join(_digit_to_text(int(digit)) for digit in decimal_str)
result = int_text + decimal_text
else:
# Handle integers
result = _int_to_text(num)
# Add 'negative' prefix for negative numbers
if negative:
result = "negative " + result
return result
def _int_to_text(num):
"""Helper function to convert an integer to text"""
ones = ["", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine",
"ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", "sixteen",
"seventeen", "eighteen", "nineteen"]
tens = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"]
if num < 20:
return ones[num]
if num < 100:
return tens[num // 10] + (" " + ones[num % 10] if num % 10 != 0 else "")
if num < 1000:
return ones[num // 100] + " hundred" + (" " + _int_to_text(num % 100) if num % 100 != 0 else "")
if num < 1000000:
return _int_to_text(num // 1000) + " thousand" + (" " + _int_to_text(num % 1000) if num % 1000 != 0 else "")
if num < 1000000000:
return _int_to_text(num // 1000000) + " million" + (" " + _int_to_text(num % 1000000) if num % 1000000 != 0 else "")
return _int_to_text(num // 1000000000) + " billion" + (" " + _int_to_text(num % 1000000000) if num % 1000000000 != 0 else "")
def _digit_to_text(digit):
"""Convert a single digit to text"""
digits = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
return digits[digit]
_whitespace_re = re.compile(r"\s+")
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = {
"en": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
("dr", "doctor"),
("st", "saint"),
("co", "company"),
("jr", "junior"),
("maj", "major"),
("gen", "general"),
("drs", "doctors"),
("rev", "reverend"),
("lt", "lieutenant"),
("hon", "honorable"),
("sgt", "sergeant"),
("capt", "captain"),
("esq", "esquire"),
("ltd", "limited"),
("col", "colonel"),
("ft", "fort"),
]
],
}
def expand_abbreviations_multilingual(text, lang="en"):
for regex, replacement in _abbreviations[lang]:
text = re.sub(regex, replacement, text)
return text
_symbols_multilingual = {
"en": [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " and "),
("@", " at "),
("%", " percent "),
("#", " hash "),
("$", " dollar "),
("£", " pound "),
("°", " degree "),
]
],
}
def expand_symbols_multilingual(text, lang="en"):
for regex, replacement in _symbols_multilingual[lang]:
text = re.sub(regex, replacement, text)
text = text.replace(" ", " ") # Ensure there are no double spaces
return text.strip()
_ordinal_re = {
"en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
}
_number_re = re.compile(r"[0-9]+")
_currency_re = {
"USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
"GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
"EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
}
_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
_dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
def _remove_commas(m):
text = m.group(0)
if "," in text:
text = text.replace(",", "")
return text
def _remove_dots(m):
text = m.group(0)
if "." in text:
text = text.replace(".", "")
return text
def _expand_decimal_point(m, lang="en"):
amount = m.group(1).replace(",", ".")
return number_to_text(float(amount))
def _expand_currency(m, lang="en", currency="USD"):
amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
full_amount = number_to_text(amount)
and_equivalents = {
"en": ", ",
"es": " con ",
"fr": " et ",
"de": " und ",
"pt": " e ",
"it": " e ",
"pl": ", ",
"cs": ", ",
"ru": ", ",
"nl": ", ",
"ar": ", ",
"tr": ", ",
"hu": ", ",
"ko": ", ",
}
if amount.is_integer():
last_and = full_amount.rfind(and_equivalents[lang])
if last_and != -1:
full_amount = full_amount[:last_and]
return full_amount
def _expand_ordinal(m, lang="en"):
return number_to_text(int(m.group(1)), ordinal=True)
def _expand_number(m, lang="en"):
return number_to_text(int(m.group(0)))
def expand_numbers_multilingual(text, lang="en"):
if lang in ["en", "ru"]:
text = re.sub(_comma_number_re, _remove_commas, text)
else:
text = re.sub(_dot_number_re, _remove_dots, text)
try:
text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
except:
pass
text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
return text
def lowercase(text):
return text.lower()
def collapse_whitespace(text):
return re.sub(_whitespace_re, " ", text)
def multilingual_cleaners(text, lang):
text = text.replace('"', "")
if lang == "tr":
text = text.replace("İ", "i")
text = text.replace("Ö", "ö")
text = text.replace("Ü", "ü")
text = lowercase(text)
try:
text = expand_numbers_multilingual(text, lang)
except:
pass
try:
text = expand_abbreviations_multilingual(text, lang)
except:
pass
try:
text = expand_symbols_multilingual(text, lang=lang)
except:
pass
text = collapse_whitespace(text)
return text
def basic_cleaners(text):
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
text = lowercase(text)
text = collapse_whitespace(text)
return text

View File

@ -0,0 +1,22 @@
{
"d_ff": 2048,
"d_kv": 64,
"d_model": 768,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"dense_act_fn": "gelu_pytorch_tanh",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "umt5",
"num_decoder_layers": 12,
"num_heads": 12,
"num_layers": 12,
"output_past": true,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"vocab_size": 256384
}

46
comfy_extras/nodes_ace.py Normal file
View File

@ -0,0 +1,46 @@
import torch
import comfy.model_management
class TextEncodeAceStepAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"tags": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "conditioning"
def encode(self, clip, tags, lyrics):
tokens = clip.tokenize(tags, lyrics=lyrics)
return (clip.encode_from_tokens_scheduled(tokens), )
class EmptyAceStepLatentAudio:
def __init__(self):
self.device = comfy.model_management.intermediate_device()
@classmethod
def INPUT_TYPES(s):
return {"required": {"seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/audio"
def generate(self, seconds, batch_size):
length = int(seconds * 44100 / 512 / 8)
latent = torch.zeros([batch_size, 8, 16, length], device=self.device)
return ({"samples": latent, "type": "audio"}, )
NODE_CLASS_MAPPINGS = {
"TextEncodeAceStepAudio": TextEncodeAceStepAudio,
"EmptyAceStepLatentAudio": EmptyAceStepLatentAudio,
}

View File

@ -246,6 +246,9 @@ class ConditioningZeroOut:
pooled_output = d.get("pooled_output", None)
if pooled_output is not None:
d["pooled_output"] = torch.zeros_like(pooled_output)
conditioning_lyrics = d.get("conditioning_lyrics", None)
if conditioning_lyrics is not None:
d["conditioning_lyrics"] = torch.zeros_like(conditioning_lyrics)
n = [torch.zeros_like(t[0]), d]
c.append(n)
return (c, )
@ -917,7 +920,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@ -2259,6 +2262,7 @@ def init_builtin_extra_nodes():
"nodes_hidream.py",
"nodes_fresca.py",
"nodes_preview_any.py",
"nodes_ace.py",
]
import_failed = []