Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-06-08 07:07:14 +00:00)

Initial ACE-Step model implementation. (#7972)
Commit 16417b40d9 (parent 271c9c5b9e)
@@ -466,3 +466,7 @@ class Hunyuan3Dv2mini(LatentFormat):
    latent_channels = 64
    latent_dimensions = 1
    scale_factor = 1.0188137142395404

class ACEAudio(LatentFormat):
    latent_channels = 8
    latent_dimensions = 2
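For context, ComfyUI latent-format subclasses like the new ACEAudio mostly declare the latent shape and the scaling applied when tensors move between the sampler and the VAE. Below is a minimal, self-contained sketch of that pattern, assuming the base class applies scale_factor in process_in/process_out as in comfy/latent_formats.py; the stand-in class here is illustrative only, not the actual ComfyUI implementation.

import torch

class LatentFormat:
    """Minimal stand-in for comfy.latent_formats.LatentFormat (illustrative only)."""
    scale_factor = 1.0
    latent_channels = 4
    latent_dimensions = 2

    def process_in(self, latent):
        # scale VAE latents into the range the diffusion model expects
        return latent * self.scale_factor

    def process_out(self, latent):
        # undo the scaling before handing latents back to the VAE decoder
        return latent / self.scale_factor

class ACEAudio(LatentFormat):
    # ACE-Step audio latents: 8 channels over a 2D (e.g. 16 x time-frames) grid
    latent_channels = 8
    latent_dimensions = 2

fmt = ACEAudio()
x = torch.randn(1, fmt.latent_channels, 16, 256)  # [B, C, H, W]-shaped audio latent
assert torch.allclose(fmt.process_out(fmt.process_in(x)), x)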
768  comfy/ldm/ace/attention.py  Normal file
@@ -0,0 +1,768 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/attention.py
|
||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import Tuple, Union, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
query_dim: int,
|
||||||
|
cross_attention_dim: Optional[int] = None,
|
||||||
|
heads: int = 8,
|
||||||
|
kv_heads: Optional[int] = None,
|
||||||
|
dim_head: int = 64,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
bias: bool = False,
|
||||||
|
qk_norm: Optional[str] = None,
|
||||||
|
added_kv_proj_dim: Optional[int] = None,
|
||||||
|
added_proj_bias: Optional[bool] = True,
|
||||||
|
out_bias: bool = True,
|
||||||
|
scale_qk: bool = True,
|
||||||
|
only_cross_attention: bool = False,
|
||||||
|
eps: float = 1e-5,
|
||||||
|
rescale_output_factor: float = 1.0,
|
||||||
|
residual_connection: bool = False,
|
||||||
|
processor=None,
|
||||||
|
out_dim: int = None,
|
||||||
|
out_context_dim: int = None,
|
||||||
|
context_pre_only=None,
|
||||||
|
pre_only=False,
|
||||||
|
elementwise_affine: bool = True,
|
||||||
|
is_causal: bool = False,
|
||||||
|
dtype=None, device=None, operations=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||||
|
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
|
||||||
|
self.query_dim = query_dim
|
||||||
|
self.use_bias = bias
|
||||||
|
self.is_cross_attention = cross_attention_dim is not None
|
||||||
|
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||||
|
self.rescale_output_factor = rescale_output_factor
|
||||||
|
self.residual_connection = residual_connection
|
||||||
|
self.dropout = dropout
|
||||||
|
self.fused_projections = False
|
||||||
|
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||||
|
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
|
||||||
|
self.context_pre_only = context_pre_only
|
||||||
|
self.pre_only = pre_only
|
||||||
|
self.is_causal = is_causal
|
||||||
|
|
||||||
|
self.scale_qk = scale_qk
|
||||||
|
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
||||||
|
|
||||||
|
self.heads = out_dim // dim_head if out_dim is not None else heads
|
||||||
|
# for slice_size > 0 the attention score computation
|
||||||
|
# is split across the batch axis to save memory
|
||||||
|
# You can set slice_size with `set_attention_slice`
|
||||||
|
self.sliceable_head_dim = heads
|
||||||
|
|
||||||
|
self.added_kv_proj_dim = added_kv_proj_dim
|
||||||
|
self.only_cross_attention = only_cross_attention
|
||||||
|
|
||||||
|
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
||||||
|
raise ValueError(
|
||||||
|
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.group_norm = None
|
||||||
|
self.spatial_norm = None
|
||||||
|
|
||||||
|
self.norm_q = None
|
||||||
|
self.norm_k = None
|
||||||
|
|
||||||
|
self.norm_cross = None
|
||||||
|
self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
if not self.only_cross_attention:
|
||||||
|
# only relevant for the `AddedKVProcessor` classes
|
||||||
|
self.to_k = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
|
||||||
|
self.to_v = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
|
||||||
|
else:
|
||||||
|
self.to_k = None
|
||||||
|
self.to_v = None
|
||||||
|
|
||||||
|
self.added_proj_bias = added_proj_bias
|
||||||
|
if self.added_kv_proj_dim is not None:
|
||||||
|
self.add_k_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
|
||||||
|
self.add_v_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
|
||||||
|
if self.context_pre_only is not None:
|
||||||
|
self.add_q_proj = operations.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias, dtype=dtype, device=device)
|
||||||
|
else:
|
||||||
|
self.add_q_proj = None
|
||||||
|
self.add_k_proj = None
|
||||||
|
self.add_v_proj = None
|
||||||
|
|
||||||
|
if not self.pre_only:
|
||||||
|
self.to_out = nn.ModuleList([])
|
||||||
|
self.to_out.append(operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device))
|
||||||
|
self.to_out.append(nn.Dropout(dropout))
|
||||||
|
else:
|
||||||
|
self.to_out = None
|
||||||
|
|
||||||
|
if self.context_pre_only is not None and not self.context_pre_only:
|
||||||
|
self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
|
||||||
|
else:
|
||||||
|
self.to_add_out = None
|
||||||
|
|
||||||
|
self.norm_added_q = None
|
||||||
|
self.norm_added_k = None
|
||||||
|
self.processor = processor
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
**cross_attention_kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return self.processor(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
**cross_attention_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomLiteLAProcessor2_0:
|
||||||
|
"""Attention processor used typically in processing the SD3-like self-attention projections. add rms norm for query and key and apply RoPE"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.kernel_func = nn.ReLU(inplace=False)
|
||||||
|
self.eps = 1e-15
|
||||||
|
self.pad_val = 1.0
|
||||||
|
|
||||||
|
def apply_rotary_emb(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
||||||
|
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
||||||
|
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
||||||
|
tensors contain rotary embeddings and are returned as real tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (`torch.Tensor`):
|
||||||
|
Query or key tensor to apply rotary embeddings to, of shape [B, H, S, D].
|
||||||
|
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
||||||
|
"""
|
||||||
|
cos, sin = freqs_cis # [S, D]
|
||||||
|
cos = cos[None, None]
|
||||||
|
sin = sin[None, None]
|
||||||
|
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||||
|
|
||||||
|
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||||
|
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||||
|
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
attn: Attention,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
encoder_hidden_states: torch.FloatTensor = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||||
|
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
hidden_states_len = hidden_states.shape[1]
|
||||||
|
|
||||||
|
input_ndim = hidden_states.ndim
|
||||||
|
if input_ndim == 4:
|
||||||
|
batch_size, channel, height, width = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
context_input_ndim = encoder_hidden_states.ndim
|
||||||
|
if context_input_ndim == 4:
|
||||||
|
batch_size, channel, height, width = encoder_hidden_states.shape
|
||||||
|
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||||
|
|
||||||
|
batch_size = hidden_states.shape[0]
|
||||||
|
|
||||||
|
# `sample` projections.
|
||||||
|
dtype = hidden_states.dtype
|
||||||
|
query = attn.to_q(hidden_states)
|
||||||
|
key = attn.to_k(hidden_states)
|
||||||
|
value = attn.to_v(hidden_states)
|
||||||
|
|
||||||
|
# `context` projections.
|
||||||
|
has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
|
||||||
|
if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
|
||||||
|
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
||||||
|
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||||
|
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
if not attn.is_cross_attention:
|
||||||
|
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
|
||||||
|
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
|
||||||
|
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
|
||||||
|
else:
|
||||||
|
query = hidden_states
|
||||||
|
key = encoder_hidden_states
|
||||||
|
value = encoder_hidden_states
|
||||||
|
|
||||||
|
inner_dim = key.shape[-1]
|
||||||
|
head_dim = inner_dim // attn.heads
|
||||||
|
|
||||||
|
query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
|
||||||
|
key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
|
||||||
|
value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
|
||||||
|
|
||||||
|
# RoPE expects input of shape [B, H, S, D];
# query is currently [B, H, D, S], so transpose it to [B, H, S, D] before applying RoPE
query = query.permute(0, 1, 3, 2)  # [B, H, S, D] (from [B, H, D, S])

# Apply query and key normalization if needed
if attn.norm_q is not None:
    query = attn.norm_q(query)
if attn.norm_k is not None:
    key = attn.norm_k(key)

# Apply RoPE if needed
if rotary_freqs_cis is not None:
    query = self.apply_rotary_emb(query, rotary_freqs_cis)
    if not attn.is_cross_attention:
        key = self.apply_rotary_emb(key, rotary_freqs_cis)
    elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
        key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)

# query is now [B, H, S, D]; transpose it back to [B, H, D, S]
query = query.permute(0, 1, 3, 2)  # [B, H, D, S]

if attention_mask is not None:
    # attention_mask: [B, S] -> [B, 1, S, 1]
    attention_mask = attention_mask[:, None, :, None].to(key.dtype)  # [B, 1, S, 1]
    query = query * attention_mask.permute(0, 1, 3, 2)  # [B, H, S, D] * [B, 1, S, 1]
    if not attn.is_cross_attention:
        key = key * attention_mask  # key: [B, h, S, D] multiplied by mask [B, 1, S, 1]
        value = value * attention_mask.permute(0, 1, 3, 2)  # value is [B, h, D, S], so the mask is permuted to match the S dimension

if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
    encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(key.dtype)  # [B, 1, S_enc, 1]
    # at this point key: [B, h, S_enc, D], value: [B, h, D, S_enc]
    key = key * encoder_attention_mask  # [B, h, S_enc, D] * [B, 1, S_enc, 1]
    value = value * encoder_attention_mask.permute(0, 1, 3, 2)  # [B, h, D, S_enc] * [B, 1, 1, S_enc]
|
||||||
|
|
||||||
|
query = self.kernel_func(query)
|
||||||
|
key = self.kernel_func(key)
|
||||||
|
|
||||||
|
query, key, value = query.float(), key.float(), value.float()
|
||||||
|
|
||||||
|
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
|
||||||
|
|
||||||
|
vk = torch.matmul(value, key)
|
||||||
|
|
||||||
|
hidden_states = torch.matmul(vk, query)
|
||||||
|
|
||||||
|
if hidden_states.dtype in [torch.float16, torch.bfloat16]:
|
||||||
|
hidden_states = hidden_states.float()
|
||||||
|
|
||||||
|
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.to(dtype)
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
encoder_hidden_states = encoder_hidden_states.to(dtype)
|
||||||
|
|
||||||
|
# Split the attention outputs.
|
||||||
|
if encoder_hidden_states is not None and not attn.is_cross_attention and has_encoder_hidden_state_proj:
|
||||||
|
hidden_states, encoder_hidden_states = (
|
||||||
|
hidden_states[:, : hidden_states_len],
|
||||||
|
hidden_states[:, hidden_states_len:],
|
||||||
|
)
|
||||||
|
|
||||||
|
# linear proj
|
||||||
|
hidden_states = attn.to_out[0](hidden_states)
|
||||||
|
# dropout
|
||||||
|
hidden_states = attn.to_out[1](hidden_states)
|
||||||
|
if encoder_hidden_states is not None and not attn.context_pre_only and not attn.is_cross_attention and hasattr(attn, "to_add_out"):
|
||||||
|
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||||
|
|
||||||
|
if input_ndim == 4:
|
||||||
|
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||||
|
if encoder_hidden_states is not None and context_input_ndim == 4:
|
||||||
|
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||||
|
|
||||||
|
if torch.get_autocast_gpu_dtype() == torch.float16:
|
||||||
|
hidden_states = hidden_states.clip(-65504, 65504)
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
||||||
|
|
||||||
|
return hidden_states, encoder_hidden_states
|
||||||
|
|
||||||
|
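Before the second processor, it may help to see the linear-attention trick used by CustomLiteLAProcessor2_0 in isolation. The sketch below is a self-contained re-derivation (not part of the diff): an extra row of ones is appended to the value tensor so that one pair of matmuls yields both the unnormalized output and its normalizer, keeping the cost linear in sequence length.

import torch
import torch.nn.functional as F

def relu_linear_attention(query, key, value, eps=1e-15):
    """Illustrative ReLU-kernel linear attention with the pad-with-ones normalizer.

    Shapes follow CustomLiteLAProcessor2_0: query/value are [B, H, D, S], key is [B, H, S, D].
    """
    query = F.relu(query)
    key = F.relu(key)
    # Append a row of ones to value so the last output row accumulates the normalizer.
    value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)  # [B, H, D+1, S]
    vk = torch.matmul(value, key)    # [B, H, D+1, D] -- cost grows linearly with S
    out = torch.matmul(vk, query)    # [B, H, D+1, S]
    return out[:, :, :-1] / (out[:, :, -1:] + eps)  # [B, H, D, S]

B, H, D, S = 1, 2, 8, 16
q = torch.randn(B, H, D, S)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, D, S)
print(relu_linear_attention(q, k, v).shape)  # torch.Size([1, 2, 8, 16])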
|
||||||
|
class CustomerAttnProcessor2_0:
|
||||||
|
r"""
|
||||||
|
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
if not hasattr(F, "scaled_dot_product_attention"):
|
||||||
|
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||||
|
|
||||||
|
def apply_rotary_emb(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
||||||
|
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
||||||
|
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
||||||
|
tensors contain rotary embeddings and are returned as real tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (`torch.Tensor`):
|
||||||
|
Query or key tensor to apply rotary embeddings to, of shape [B, H, S, D].
|
||||||
|
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
||||||
|
"""
|
||||||
|
cos, sin = freqs_cis # [S, D]
|
||||||
|
cos = cos[None, None]
|
||||||
|
sin = sin[None, None]
|
||||||
|
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||||
|
|
||||||
|
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||||
|
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||||
|
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
attn: Attention,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
encoder_hidden_states: torch.FloatTensor = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||||
|
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
input_ndim = hidden_states.ndim
|
||||||
|
|
||||||
|
if input_ndim == 4:
|
||||||
|
batch_size, channel, height, width = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||||
|
|
||||||
|
batch_size, sequence_length, _ = (
|
||||||
|
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||||
|
)
|
||||||
|
|
||||||
|
has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
|
||||||
|
|
||||||
|
if attn.group_norm is not None:
|
||||||
|
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||||
|
|
||||||
|
query = attn.to_q(hidden_states)
|
||||||
|
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
encoder_hidden_states = hidden_states
|
||||||
|
elif attn.norm_cross:
|
||||||
|
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||||
|
|
||||||
|
key = attn.to_k(encoder_hidden_states)
|
||||||
|
value = attn.to_v(encoder_hidden_states)
|
||||||
|
|
||||||
|
inner_dim = key.shape[-1]
|
||||||
|
head_dim = inner_dim // attn.heads
|
||||||
|
|
||||||
|
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
if attn.norm_q is not None:
|
||||||
|
query = attn.norm_q(query)
|
||||||
|
if attn.norm_k is not None:
|
||||||
|
key = attn.norm_k(key)
|
||||||
|
|
||||||
|
# Apply RoPE if needed
|
||||||
|
if rotary_freqs_cis is not None:
|
||||||
|
query = self.apply_rotary_emb(query, rotary_freqs_cis)
|
||||||
|
if not attn.is_cross_attention:
|
||||||
|
key = self.apply_rotary_emb(key, rotary_freqs_cis)
|
||||||
|
elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
|
||||||
|
key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
|
||||||
|
|
||||||
|
if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
|
||||||
|
# attention_mask: N x S1
|
||||||
|
# encoder_attention_mask: N x S2
|
||||||
|
# cross attention: combine attention_mask and encoder_attention_mask
|
||||||
|
combined_mask = attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
|
||||||
|
attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
|
||||||
|
attention_mask = attention_mask[:, None, :, :].expand(-1, attn.heads, -1, -1).to(query.dtype)
|
||||||
|
|
||||||
|
elif not attn.is_cross_attention and attention_mask is not None:
|
||||||
|
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||||
|
# scaled_dot_product_attention expects attention_mask shape to be
|
||||||
|
# (batch, heads, source_length, target_length)
|
||||||
|
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||||
|
|
||||||
|
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||||
|
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||||
|
hidden_states = F.scaled_dot_product_attention(
|
||||||
|
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
|
hidden_states = hidden_states.to(query.dtype)
|
||||||
|
|
||||||
|
# linear proj
|
||||||
|
hidden_states = attn.to_out[0](hidden_states)
|
||||||
|
# dropout
|
||||||
|
hidden_states = attn.to_out[1](hidden_states)
|
||||||
|
|
||||||
|
if input_ndim == 4:
|
||||||
|
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||||
|
|
||||||
|
if attn.residual_connection:
|
||||||
|
hidden_states = hidden_states + residual
|
||||||
|
|
||||||
|
hidden_states = hidden_states / attn.rescale_output_factor
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def val2list(x: list or tuple or any, repeat_time=1) -> list: # type: ignore
|
||||||
|
"""Repeat `val` for `repeat_time` times and return the list or val if list/tuple."""
|
||||||
|
if isinstance(x, (list, tuple)):
|
||||||
|
return list(x)
|
||||||
|
return [x for _ in range(repeat_time)]
|
||||||
|
|
||||||
|
|
||||||
|
def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple: # type: ignore
|
||||||
|
"""Return tuple with min_len by repeating element at idx_repeat."""
|
||||||
|
# convert to list first
|
||||||
|
x = val2list(x)
|
||||||
|
|
||||||
|
# repeat elements if necessary
|
||||||
|
if len(x) > 0:
|
||||||
|
x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
|
||||||
|
|
||||||
|
return tuple(x)
|
||||||
|
|
||||||
|
|
||||||
|
def t2i_modulate(x, shift, scale):
|
||||||
|
return x * (1 + scale) + shift
|
||||||
|
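A quick illustration of the two helpers above, assuming the definitions just shown are in scope; the argument values are hypothetical and only demonstrate the behavior.

import torch

# val2tuple pads a scalar or short sequence out to the requested length
print(val2tuple(True, 3))             # (True, True, True)
print(val2tuple((None, "silu"), 3))   # (None, 'silu', 'silu')

# t2i_modulate applies AdaLN-style shift/scale: x * (1 + scale) + shift
x = torch.ones(1, 4, 8)
print(t2i_modulate(x, torch.zeros(1, 1, 8), torch.full((1, 1, 8), 0.5)))  # tensor of 1.5s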
|
||||||
|
|
||||||
|
def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
|
||||||
|
if isinstance(kernel_size, tuple):
|
||||||
|
return tuple([get_same_padding(ks) for ks in kernel_size])
|
||||||
|
else:
|
||||||
|
assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
|
||||||
|
return kernel_size // 2
|
||||||
|
|
||||||
|
class ConvLayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_dim: int,
|
||||||
|
out_dim: int,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
dilation=1,
|
||||||
|
groups=1,
|
||||||
|
padding: Union[int, None] = None,
|
||||||
|
use_bias=False,
|
||||||
|
norm=None,
|
||||||
|
act=None,
|
||||||
|
dtype=None, device=None, operations=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if padding is None:
|
||||||
|
padding = get_same_padding(kernel_size)
|
||||||
|
padding *= dilation
|
||||||
|
|
||||||
|
self.in_dim = in_dim
|
||||||
|
self.out_dim = out_dim
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.stride = stride
|
||||||
|
self.dilation = dilation
|
||||||
|
self.groups = groups
|
||||||
|
self.padding = padding
|
||||||
|
self.use_bias = use_bias
|
||||||
|
|
||||||
|
self.conv = operations.Conv1d(
|
||||||
|
in_dim,
|
||||||
|
out_dim,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
bias=use_bias,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype
|
||||||
|
)
|
||||||
|
if norm is not None:
|
||||||
|
self.norm = operations.RMSNorm(out_dim, elementwise_affine=False, dtype=dtype, device=device)
|
||||||
|
else:
|
||||||
|
self.norm = None
|
||||||
|
if act is not None:
|
||||||
|
self.act = nn.SiLU(inplace=True)
|
||||||
|
else:
|
||||||
|
self.act = None
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = self.conv(x)
|
||||||
|
if self.norm:
|
||||||
|
x = self.norm(x)
|
||||||
|
if self.act:
|
||||||
|
x = self.act(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class GLUMBConv(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_features: int,
|
||||||
|
hidden_features: int,
|
||||||
|
out_feature=None,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding: Union[int, None] = None,
|
||||||
|
use_bias=False,
|
||||||
|
norm=(None, None, None),
|
||||||
|
act=("silu", "silu", None),
|
||||||
|
dilation=1,
|
||||||
|
dtype=None, device=None, operations=None
|
||||||
|
):
|
||||||
|
out_feature = out_feature or in_features
|
||||||
|
super().__init__()
|
||||||
|
use_bias = val2tuple(use_bias, 3)
|
||||||
|
norm = val2tuple(norm, 3)
|
||||||
|
act = val2tuple(act, 3)
|
||||||
|
|
||||||
|
self.glu_act = nn.SiLU(inplace=False)
|
||||||
|
self.inverted_conv = ConvLayer(
|
||||||
|
in_features,
|
||||||
|
hidden_features * 2,
|
||||||
|
1,
|
||||||
|
use_bias=use_bias[0],
|
||||||
|
norm=norm[0],
|
||||||
|
act=act[0],
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
self.depth_conv = ConvLayer(
|
||||||
|
hidden_features * 2,
|
||||||
|
hidden_features * 2,
|
||||||
|
kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
groups=hidden_features * 2,
|
||||||
|
padding=padding,
|
||||||
|
use_bias=use_bias[1],
|
||||||
|
norm=norm[1],
|
||||||
|
act=None,
|
||||||
|
dilation=dilation,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
self.point_conv = ConvLayer(
|
||||||
|
hidden_features,
|
||||||
|
out_feature,
|
||||||
|
1,
|
||||||
|
use_bias=use_bias[2],
|
||||||
|
norm=norm[2],
|
||||||
|
act=act[2],
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = x.transpose(1, 2)
|
||||||
|
x = self.inverted_conv(x)
|
||||||
|
x = self.depth_conv(x)
|
||||||
|
|
||||||
|
x, gate = torch.chunk(x, 2, dim=1)
|
||||||
|
gate = self.glu_act(gate)
|
||||||
|
x = x * gate
|
||||||
|
|
||||||
|
x = self.point_conv(x)
|
||||||
|
x = x.transpose(1, 2)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class LinearTransformerBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
num_attention_heads,
|
||||||
|
attention_head_dim,
|
||||||
|
use_adaln_single=True,
|
||||||
|
cross_attention_dim=None,
|
||||||
|
added_kv_proj_dim=None,
|
||||||
|
context_pre_only=False,
|
||||||
|
mlp_ratio=4.0,
|
||||||
|
add_cross_attention=False,
|
||||||
|
add_cross_attention_dim=None,
|
||||||
|
qk_norm=None,
|
||||||
|
dtype=None, device=None, operations=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.norm1 = operations.RMSNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.attn = Attention(
|
||||||
|
query_dim=dim,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
added_kv_proj_dim=added_kv_proj_dim,
|
||||||
|
dim_head=attention_head_dim,
|
||||||
|
heads=num_attention_heads,
|
||||||
|
out_dim=dim,
|
||||||
|
bias=True,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
processor=CustomLiteLAProcessor2_0(),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.add_cross_attention = add_cross_attention
|
||||||
|
self.context_pre_only = context_pre_only
|
||||||
|
|
||||||
|
if add_cross_attention and add_cross_attention_dim is not None:
|
||||||
|
self.cross_attn = Attention(
|
||||||
|
query_dim=dim,
|
||||||
|
cross_attention_dim=add_cross_attention_dim,
|
||||||
|
added_kv_proj_dim=add_cross_attention_dim,
|
||||||
|
dim_head=attention_head_dim,
|
||||||
|
heads=num_attention_heads,
|
||||||
|
out_dim=dim,
|
||||||
|
context_pre_only=context_pre_only,
|
||||||
|
bias=True,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
processor=CustomerAttnProcessor2_0(),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm2 = operations.RMSNorm(dim, 1e-06, elementwise_affine=False)
|
||||||
|
|
||||||
|
self.ff = GLUMBConv(
|
||||||
|
in_features=dim,
|
||||||
|
hidden_features=int(dim * mlp_ratio),
|
||||||
|
use_bias=(True, True, False),
|
||||||
|
norm=(None, None, None),
|
||||||
|
act=("silu", "silu", None),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
self.use_adaln_single = use_adaln_single
|
||||||
|
if use_adaln_single:
|
||||||
|
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, dtype=dtype, device=device))
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
encoder_hidden_states: torch.FloatTensor = None,
|
||||||
|
attention_mask: torch.FloatTensor = None,
|
||||||
|
encoder_attention_mask: torch.FloatTensor = None,
|
||||||
|
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||||
|
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||||
|
temb: torch.FloatTensor = None,
|
||||||
|
):
|
||||||
|
|
||||||
|
N = hidden_states.shape[0]
|
||||||
|
|
||||||
|
# step 1: AdaLN single
|
||||||
|
if self.use_adaln_single:
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||||
|
comfy.model_management.cast_to(self.scale_shift_table[None], dtype=temb.dtype, device=temb.device) + temb.reshape(N, 6, -1)
|
||||||
|
).chunk(6, dim=1)
|
||||||
|
|
||||||
|
norm_hidden_states = self.norm1(hidden_states)
|
||||||
|
if self.use_adaln_single:
|
||||||
|
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||||
|
|
||||||
|
# step 2: attention
|
||||||
|
if not self.add_cross_attention:
|
||||||
|
attn_output, encoder_hidden_states = self.attn(
|
||||||
|
hidden_states=norm_hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
rotary_freqs_cis=rotary_freqs_cis,
|
||||||
|
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_output, _ = self.attn(
|
||||||
|
hidden_states=norm_hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
rotary_freqs_cis=rotary_freqs_cis,
|
||||||
|
rotary_freqs_cis_cross=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_adaln_single:
|
||||||
|
attn_output = gate_msa * attn_output
|
||||||
|
hidden_states = attn_output + hidden_states
|
||||||
|
|
||||||
|
if self.add_cross_attention:
|
||||||
|
attn_output = self.cross_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
rotary_freqs_cis=rotary_freqs_cis,
|
||||||
|
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
|
||||||
|
)
|
||||||
|
hidden_states = attn_output + hidden_states
|
||||||
|
|
||||||
|
# step 3: add norm
|
||||||
|
norm_hidden_states = self.norm2(hidden_states)
|
||||||
|
if self.use_adaln_single:
|
||||||
|
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
||||||
|
|
||||||
|
# step 4: feed forward
|
||||||
|
ff_output = self.ff(norm_hidden_states)
|
||||||
|
if self.use_adaln_single:
|
||||||
|
ff_output = gate_mlp * ff_output
|
||||||
|
|
||||||
|
hidden_states = hidden_states + ff_output
|
||||||
|
|
||||||
|
return hidden_states
|
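To make the adaLN-single conditioning used by LinearTransformerBlock concrete, here is a minimal standalone sketch (not part of the diff) of how a single timestep embedding drives all six modulation signals; in the real block, temb comes from the model's t_block and scale_shift_table is a learned parameter.

import torch

dim, N, S = 32, 2, 10
scale_shift_table = torch.zeros(6, dim)   # learned table in the real block; zeros here for determinism
temb = torch.randn(N, 6 * dim)            # per-sample conditioning from the timestep MLP

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
    scale_shift_table[None] + temb.reshape(N, 6, -1)
).chunk(6, dim=1)                         # each chunk: [N, 1, dim], broadcast over the sequence

x = torch.randn(N, S, dim)
norm_x = x * (1 + scale_msa) + shift_msa  # pre-attention modulation
attn_out = norm_x                         # stand-in for the attention output
x = x + gate_msa * attn_out               # gated residual update
print(x.shape)                            # torch.Size([2, 10, 32])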
1067  comfy/ldm/ace/lyric_encoder.py  Normal file
File diff suppressed because it is too large
381  comfy/ldm/ace/model.py  Normal file
@@ -0,0 +1,381 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/ace_step_transformer.py
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import Optional, List, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
|
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||||
|
from .attention import LinearTransformerBlock, t2i_modulate
|
||||||
|
from .lyric_encoder import ConformerEncoder as LyricEncoder
|
||||||
|
|
||||||
|
|
||||||
|
def cross_norm(hidden_states, controlnet_input):
|
||||||
|
# input N x T x c
|
||||||
|
mean_hidden_states, std_hidden_states = hidden_states.mean(dim=(1,2), keepdim=True), hidden_states.std(dim=(1,2), keepdim=True)
|
||||||
|
mean_controlnet_input, std_controlnet_input = controlnet_input.mean(dim=(1,2), keepdim=True), controlnet_input.std(dim=(1,2), keepdim=True)
|
||||||
|
controlnet_input = (controlnet_input - mean_controlnet_input) * (std_hidden_states / (std_controlnet_input + 1e-12)) + mean_hidden_states
|
||||||
|
return controlnet_input
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
|
||||||
|
class Qwen2RotaryEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, dtype=None, device=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.base = base
|
||||||
|
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=device).float() / self.dim))
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
|
||||||
|
# Build here to make `torch.jit.trace` work.
|
||||||
|
self._set_cos_sin_cache(
|
||||||
|
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||||
|
self.max_seq_len_cached = seq_len
|
||||||
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||||
|
|
||||||
|
freqs = torch.outer(t, self.inv_freq)
|
||||||
|
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||||
|
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||||
|
|
||||||
|
def forward(self, x, seq_len=None):
|
||||||
|
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||||
|
if seq_len > self.max_seq_len_cached:
|
||||||
|
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
|
return (
|
||||||
|
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
||||||
|
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
||||||
|
)
|
||||||
|
|
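As a side note, the cos/sin tables produced by Qwen2RotaryEmbedding pair with the apply_rotary_emb helpers in attention.py. The following is a self-contained sketch in the spirit of those functions (illustrative sizes, not the classes themselves).

import torch

def rope_tables(seq_len, dim, base=1000000.0):
    # inverse frequencies over even indices, duplicated so cos/sin cover the full head dim
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)  # [S, D]
    return emb.cos(), emb.sin()

def apply_rope(x, cos, sin):
    # x: [B, H, S, D]; build the rotated companion by interleaving (-x_imag, x_real)
    cos, sin = cos[None, None], sin[None, None]
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
    return x * cos + x_rotated * sin

cos, sin = rope_tables(seq_len=16, dim=64)
q = torch.randn(1, 4, 16, 64)
print(apply_rope(q, cos, sin).shape)  # torch.Size([1, 4, 16, 64])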
||||||
|
|
||||||
|
class T2IFinalLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
The final layer of Sana.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_final = operations.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.linear = operations.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True, dtype=dtype, device=device)
|
||||||
|
self.scale_shift_table = nn.Parameter(torch.empty(2, hidden_size, dtype=dtype, device=device))
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.patch_size = patch_size
|
||||||
|
|
||||||
|
def unpatchfy(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
width: int,
|
||||||
|
):
|
||||||
|
# 4 unpatchify
|
||||||
|
new_height, new_width = 1, hidden_states.size(1)
|
||||||
|
hidden_states = hidden_states.reshape(
|
||||||
|
shape=(hidden_states.shape[0], new_height, new_width, self.patch_size[0], self.patch_size[1], self.out_channels)
|
||||||
|
).contiguous()
|
||||||
|
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||||
|
output = hidden_states.reshape(
|
||||||
|
shape=(hidden_states.shape[0], self.out_channels, new_height * self.patch_size[0], new_width * self.patch_size[1])
|
||||||
|
).contiguous()
|
||||||
|
if width > new_width:
|
||||||
|
output = torch.nn.functional.pad(output, (0, width - new_width, 0, 0), 'constant', 0)
|
||||||
|
elif width < new_width:
|
||||||
|
output = output[:, :, :, :width]
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward(self, x, t, output_length):
|
||||||
|
shift, scale = (comfy.model_management.cast_to(self.scale_shift_table[None], device=t.device, dtype=t.dtype) + t[:, None]).chunk(2, dim=1)
|
||||||
|
x = t2i_modulate(self.norm_final(x), shift, scale)
|
||||||
|
x = self.linear(x)
|
||||||
|
# unpatchify
|
||||||
|
output = self.unpatchfy(x, output_length)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEmbed(nn.Module):
|
||||||
|
"""2D Image to Patch Embedding"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
height=16,
|
||||||
|
width=4096,
|
||||||
|
patch_size=(16, 1),
|
||||||
|
in_channels=8,
|
||||||
|
embed_dim=1152,
|
||||||
|
bias=True,
|
||||||
|
dtype=None, device=None, operations=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
patch_size_h, patch_size_w = patch_size
|
||||||
|
self.early_conv_layers = nn.Sequential(
|
||||||
|
operations.Conv2d(in_channels, in_channels*256, kernel_size=patch_size, stride=patch_size, padding=0, bias=bias, dtype=dtype, device=device),
|
||||||
|
operations.GroupNorm(num_groups=32, num_channels=in_channels*256, eps=1e-6, affine=True, dtype=dtype, device=device),
|
||||||
|
operations.Conv2d(in_channels*256, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.height, self.width = height // patch_size_h, width // patch_size_w
|
||||||
|
self.base_size = self.width
|
||||||
|
|
||||||
|
def forward(self, latent):
|
||||||
|
# early convolutions, N x C x H x W -> N x 256 * sqrt(patch_size) x H/patch_size x W/patch_size
|
||||||
|
latent = self.early_conv_layers(latent)
|
||||||
|
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||||
|
return latent
|
||||||
|
|
||||||
|
|
||||||
|
class ACEStepTransformer2DModel(nn.Module):
|
||||||
|
# _supports_gradient_checkpointing = True
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: Optional[int] = 8,
|
||||||
|
num_layers: int = 28,
|
||||||
|
inner_dim: int = 1536,
|
||||||
|
attention_head_dim: int = 64,
|
||||||
|
num_attention_heads: int = 24,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
out_channels: int = 8,
|
||||||
|
max_position: int = 32768,
|
||||||
|
rope_theta: float = 1000000.0,
|
||||||
|
speaker_embedding_dim: int = 512,
|
||||||
|
text_embedding_dim: int = 768,
|
||||||
|
ssl_encoder_depths: List[int] = [9, 9],
|
||||||
|
ssl_names: List[str] = ["mert", "m-hubert"],
|
||||||
|
ssl_latent_dims: List[int] = [1024, 768],
|
||||||
|
lyric_encoder_vocab_size: int = 6681,
|
||||||
|
lyric_hidden_size: int = 1024,
|
||||||
|
patch_size: List[int] = [16, 1],
|
||||||
|
max_height: int = 16,
|
||||||
|
max_width: int = 4096,
|
||||||
|
audio_model=None,
|
||||||
|
dtype=None, device=None, operations=None
|
||||||
|
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dtype = dtype
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.attention_head_dim = attention_head_dim
|
||||||
|
inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
self.inner_dim = inner_dim
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.max_position = max_position
|
||||||
|
self.patch_size = patch_size
|
||||||
|
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
|
||||||
|
self.rotary_emb = Qwen2RotaryEmbedding(
|
||||||
|
dim=self.attention_head_dim,
|
||||||
|
max_position_embeddings=self.max_position,
|
||||||
|
base=self.rope_theta,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Define input layers
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
self.num_layers = num_layers
|
||||||
|
# 3. Define transformers blocks
|
||||||
|
self.transformer_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
LinearTransformerBlock(
|
||||||
|
dim=self.inner_dim,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
add_cross_attention=True,
|
||||||
|
add_cross_attention_dim=self.inner_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
for i in range(self.num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||||
|
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||||
|
self.t_block = nn.Sequential(nn.SiLU(), operations.Linear(self.inner_dim, 6 * self.inner_dim, bias=True, dtype=dtype, device=device))
|
||||||
|
|
||||||
|
# speaker
|
||||||
|
self.speaker_embedder = operations.Linear(speaker_embedding_dim, self.inner_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
# genre
|
||||||
|
self.genre_embedder = operations.Linear(text_embedding_dim, self.inner_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
# lyric
|
||||||
|
self.lyric_embs = operations.Embedding(lyric_encoder_vocab_size, lyric_hidden_size, dtype=dtype, device=device)
|
||||||
|
self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0, dtype=dtype, device=device, operations=operations)
|
||||||
|
self.lyric_proj = operations.Linear(lyric_hidden_size, self.inner_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
projector_dim = 2 * self.inner_dim
|
||||||
|
|
||||||
|
self.projectors = nn.ModuleList([
|
||||||
|
nn.Sequential(
|
||||||
|
operations.Linear(self.inner_dim, projector_dim, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(projector_dim, projector_dim, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(projector_dim, ssl_dim, dtype=dtype, device=device),
|
||||||
|
) for ssl_dim in ssl_latent_dims
|
||||||
|
])
|
||||||
|
|
||||||
|
self.proj_in = PatchEmbed(
|
||||||
|
height=max_height,
|
||||||
|
width=max_width,
|
||||||
|
patch_size=patch_size,
|
||||||
|
embed_dim=self.inner_dim,
|
||||||
|
bias=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
def forward_lyric_encoder(
|
||||||
|
self,
|
||||||
|
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||||
|
lyric_mask: Optional[torch.LongTensor] = None,
|
||||||
|
out_dtype=None,
|
||||||
|
):
|
||||||
|
# N x T x D
|
||||||
|
lyric_embs = self.lyric_embs(lyric_token_idx, out_dtype=out_dtype)
|
||||||
|
prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
|
||||||
|
prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
|
||||||
|
return prompt_prenet_out
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self,
|
||||||
|
encoder_text_hidden_states: Optional[torch.Tensor] = None,
|
||||||
|
text_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
|
speaker_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||||
|
lyric_mask: Optional[torch.LongTensor] = None,
|
||||||
|
):
|
||||||
|
|
||||||
|
bs = encoder_text_hidden_states.shape[0]
|
||||||
|
device = encoder_text_hidden_states.device
|
||||||
|
|
||||||
|
# speaker embedding
|
||||||
|
encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)
|
||||||
|
|
||||||
|
# genre embedding
|
||||||
|
encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)
|
||||||
|
|
||||||
|
# lyric
|
||||||
|
encoder_lyric_hidden_states = self.forward_lyric_encoder(
|
||||||
|
lyric_token_idx=lyric_token_idx,
|
||||||
|
lyric_mask=lyric_mask,
|
||||||
|
out_dtype=encoder_text_hidden_states.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1)
|
||||||
|
|
||||||
|
encoder_hidden_mask = None
|
||||||
|
if text_attention_mask is not None:
|
||||||
|
speaker_mask = torch.ones(bs, 1, device=device)
|
||||||
|
encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)
|
||||||
|
|
||||||
|
return encoder_hidden_states, encoder_hidden_mask
|
||||||
|
|
||||||
|
def decode(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
encoder_hidden_states: torch.Tensor,
|
||||||
|
encoder_hidden_mask: torch.Tensor,
|
||||||
|
timestep: Optional[torch.Tensor],
|
||||||
|
output_length: int = 0,
|
||||||
|
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||||
|
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||||
|
return_dict: bool = True,
|
||||||
|
):
|
||||||
|
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
|
||||||
|
temb = self.t_block(embedded_timestep)
|
||||||
|
|
||||||
|
hidden_states = self.proj_in(hidden_states)
|
||||||
|
|
||||||
|
# controlnet logic
|
||||||
|
if block_controlnet_hidden_states is not None:
|
||||||
|
control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
|
||||||
|
hidden_states = hidden_states + control_condi * controlnet_scale
|
||||||
|
|
||||||
|
# inner_hidden_states = []
|
||||||
|
|
||||||
|
rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
|
||||||
|
encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])
|
||||||
|
|
||||||
|
for index_block, block in enumerate(self.transformer_blocks):
|
||||||
|
hidden_states = block(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_hidden_mask,
|
||||||
|
rotary_freqs_cis=rotary_freqs_cis,
|
||||||
|
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
|
||||||
|
temb=temb,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = self.final_layer(hidden_states, embedded_timestep, output_length)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
timestep,
|
||||||
|
attention_mask=None,
|
||||||
|
context: Optional[torch.Tensor] = None,
|
||||||
|
text_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
|
speaker_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||||
|
lyric_mask: Optional[torch.LongTensor] = None,
|
||||||
|
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||||
|
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
hidden_states = x
|
||||||
|
encoder_text_hidden_states = context
|
||||||
|
encoder_hidden_states, encoder_hidden_mask = self.encode(
|
||||||
|
encoder_text_hidden_states=encoder_text_hidden_states,
|
||||||
|
text_attention_mask=text_attention_mask,
|
||||||
|
speaker_embeds=speaker_embeds,
|
||||||
|
lyric_token_idx=lyric_token_idx,
|
||||||
|
lyric_mask=lyric_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
output_length = hidden_states.shape[-1]
|
||||||
|
|
||||||
|
output = self.decode(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_hidden_mask=encoder_hidden_mask,
|
||||||
|
timestep=timestep,
|
||||||
|
output_length=output_length,
|
||||||
|
block_controlnet_hidden_states=block_controlnet_hidden_states,
|
||||||
|
controlnet_scale=controlnet_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
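The shape bookkeeping in T2IFinalLayer.unpatchfy is easier to follow on a toy tensor. Below is a standalone sketch of the token-to-latent reshape (illustrative sizes; patch_size=[16, 1] and out_channels=8 as in this model), not code from the diff.

import torch

B, tokens = 1, 200                # 200 transformer tokens -> 200 latent time frames
patch_h, patch_w, C = 16, 1, 8    # patch_size=[16, 1], out_channels=8

x = torch.randn(B, tokens, patch_h * patch_w * C)       # output of final_layer.linear
x = x.reshape(B, 1, tokens, patch_h, patch_w, C)        # n h w p q c
x = torch.einsum("nhwpqc->nchpwq", x)                    # channels first, patches unfolded
latent = x.reshape(B, C, 1 * patch_h, tokens * patch_w)  # [B, 8, 16, 200]
print(latent.shape)  # torch.Size([1, 8, 16, 200])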
644  comfy/ldm/ace/vae/autoencoder_dc.py  Normal file
@@ -0,0 +1,644 @@
# Rewritten from diffusers
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from typing import Tuple, Union
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.ops
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(ops.RMSNorm):
|
||||||
|
def __init__(self, dim, eps=1e-5, elementwise_affine=True, bias=False):
|
||||||
|
super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||||
|
if elementwise_affine:
|
||||||
|
self.bias = nn.Parameter(torch.empty(dim)) if bias else None
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = super().forward(x)
|
||||||
|
if self.elementwise_affine:
|
||||||
|
if self.bias is not None:
|
||||||
|
x = x + comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def get_normalization(norm_type, num_features, num_groups=32, eps=1e-5):
|
||||||
|
if norm_type == "batch_norm":
|
||||||
|
return nn.BatchNorm2d(num_features)
|
||||||
|
elif norm_type == "group_norm":
|
||||||
|
return ops.GroupNorm(num_groups, num_features)
|
||||||
|
elif norm_type == "layer_norm":
|
||||||
|
return ops.LayerNorm(num_features)
|
||||||
|
elif norm_type == "rms_norm":
|
||||||
|
return RMSNorm(num_features, eps=eps, elementwise_affine=True, bias=True)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown normalization type: {norm_type}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_activation(activation_type):
|
||||||
|
if activation_type == "relu":
|
||||||
|
return nn.ReLU()
|
||||||
|
elif activation_type == "relu6":
|
||||||
|
return nn.ReLU6()
|
||||||
|
elif activation_type == "silu":
|
||||||
|
return nn.SiLU()
|
||||||
|
elif activation_type == "leaky_relu":
|
||||||
|
return nn.LeakyReLU(0.2)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown activation type: {activation_type}")
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
norm_type: str = "batch_norm",
|
||||||
|
act_fn: str = "relu6",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.norm_type = norm_type
|
||||||
|
self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity()
|
||||||
|
self.conv1 = ops.Conv2d(in_channels, in_channels, 3, 1, 1)
|
||||||
|
self.conv2 = ops.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
|
||||||
|
self.norm = get_normalization(norm_type, out_channels)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.conv1(hidden_states)
|
||||||
|
hidden_states = self.nonlinearity(hidden_states)
|
||||||
|
hidden_states = self.conv2(hidden_states)
|
||||||
|
|
||||||
|
if self.norm_type == "rms_norm":
|
||||||
|
# move channel to the last dimension so we apply RMSnorm across channel dimension
|
||||||
|
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||||
|
else:
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states + residual
|
||||||
|
|
||||||
|
class SanaMultiscaleAttentionProjection(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
num_attention_heads: int,
|
||||||
|
kernel_size: int,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
channels = 3 * in_channels
|
||||||
|
self.proj_in = ops.Conv2d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
padding=kernel_size // 2,
|
||||||
|
groups=channels,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.proj_out = ops.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
hidden_states = self.proj_in(hidden_states)
|
||||||
|
hidden_states = self.proj_out(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
class SanaMultiscaleLinearAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
num_attention_heads: int = None,
|
||||||
|
attention_head_dim: int = 8,
|
||||||
|
mult: float = 1.0,
|
||||||
|
norm_type: str = "batch_norm",
|
||||||
|
kernel_sizes: tuple = (5,),
|
||||||
|
eps: float = 1e-15,
|
||||||
|
residual_connection: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.eps = eps
|
||||||
|
self.attention_head_dim = attention_head_dim
|
||||||
|
self.norm_type = norm_type
|
||||||
|
self.residual_connection = residual_connection
|
||||||
|
|
||||||
|
num_attention_heads = (
|
||||||
|
int(in_channels // attention_head_dim * mult)
|
||||||
|
if num_attention_heads is None
|
||||||
|
else num_attention_heads
|
||||||
|
)
|
||||||
|
inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
|
||||||
|
self.to_q = ops.Linear(in_channels, inner_dim, bias=False)
|
||||||
|
self.to_k = ops.Linear(in_channels, inner_dim, bias=False)
|
||||||
|
self.to_v = ops.Linear(in_channels, inner_dim, bias=False)
|
||||||
|
|
||||||
|
self.to_qkv_multiscale = nn.ModuleList()
|
||||||
|
for kernel_size in kernel_sizes:
|
||||||
|
self.to_qkv_multiscale.append(
|
||||||
|
SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.nonlinearity = nn.ReLU()
|
||||||
|
self.to_out = ops.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
|
||||||
|
self.norm_out = get_normalization(norm_type, out_channels)
|
||||||
|
|
||||||
|
def apply_linear_attention(self, query, key, value):
|
||||||
|
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)
|
||||||
|
scores = torch.matmul(value, key.transpose(-1, -2))
|
||||||
|
hidden_states = torch.matmul(scores, query)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.to(dtype=torch.float32)
|
||||||
|
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
|
||||||
|
return hidden_states
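# Explanatory note (added comment, not in the original source): padding `value`
# with a row of ones means that, after the two matmuls, the last channel of
# `hidden_states` holds the normalizer sum_m <q_n, k_m> for every position n,
# while the remaining rows hold sum_m v_m * <q_n, k_m>. The final division
# therefore yields ReLU-kernel linear attention in O(N * d^2) instead of the
# O(N^2 * d) softmax formulation.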
|
||||||
|
|
||||||
|
def apply_quadratic_attention(self, query, key, value):
|
||||||
|
scores = torch.matmul(key.transpose(-1, -2), query)
|
||||||
|
scores = scores.to(dtype=torch.float32)
|
||||||
|
scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
|
||||||
|
hidden_states = torch.matmul(value, scores.to(value.dtype))
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
height, width = hidden_states.shape[-2:]
|
||||||
|
if height * width > self.attention_head_dim:
|
||||||
|
use_linear_attention = True
|
||||||
|
else:
|
||||||
|
use_linear_attention = False
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
batch_size, _, height, width = list(hidden_states.size())
|
||||||
|
original_dtype = hidden_states.dtype
|
||||||
|
|
||||||
|
hidden_states = hidden_states.movedim(1, -1)
|
||||||
|
query = self.to_q(hidden_states)
|
||||||
|
key = self.to_k(hidden_states)
|
||||||
|
value = self.to_v(hidden_states)
|
||||||
|
hidden_states = torch.cat([query, key, value], dim=3)
|
||||||
|
hidden_states = hidden_states.movedim(-1, 1)
|
||||||
|
|
||||||
|
multi_scale_qkv = [hidden_states]
|
||||||
|
for block in self.to_qkv_multiscale:
|
||||||
|
multi_scale_qkv.append(block(hidden_states))
|
||||||
|
|
||||||
|
hidden_states = torch.cat(multi_scale_qkv, dim=1)
|
||||||
|
|
||||||
|
if use_linear_attention:
|
||||||
|
# for linear attention upcast hidden_states to float32
|
||||||
|
hidden_states = hidden_states.to(dtype=torch.float32)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.reshape(batch_size, -1, 3 * self.attention_head_dim, height * width)
|
||||||
|
|
||||||
|
query, key, value = hidden_states.chunk(3, dim=2)
|
||||||
|
query = self.nonlinearity(query)
|
||||||
|
key = self.nonlinearity(key)
|
||||||
|
|
||||||
|
if use_linear_attention:
|
||||||
|
hidden_states = self.apply_linear_attention(query, key, value)
|
||||||
|
hidden_states = hidden_states.to(dtype=original_dtype)
|
||||||
|
else:
|
||||||
|
hidden_states = self.apply_quadratic_attention(query, key, value)
|
||||||
|
|
||||||
|
hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
|
||||||
|
hidden_states = self.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||||
|
|
||||||
|
if self.norm_type == "rms_norm":
|
||||||
|
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||||
|
else:
|
||||||
|
hidden_states = self.norm_out(hidden_states)
|
||||||
|
|
||||||
|
if self.residual_connection:
|
||||||
|
hidden_states = hidden_states + residual
|
||||||
|
|
||||||
|
return hidden_states
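# Note (added comment, not in the original source): the `use_linear_attention`
# switch above picks the O(N * d^2) linear form whenever the token count
# height * width exceeds the per-head dimension, and the O(N^2 * d) quadratic
# form otherwise; both branches compute the same ReLU-kernel attention with an
# L1 normalizer, only the association of the matrix products differs.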
|
||||||
|
|
||||||
|
|
||||||
|
class EfficientViTBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
mult: float = 1.0,
|
||||||
|
attention_head_dim: int = 32,
|
||||||
|
qkv_multiscales: tuple = (5,),
|
||||||
|
norm_type: str = "batch_norm",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.attn = SanaMultiscaleLinearAttention(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=in_channels,
|
||||||
|
mult=mult,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
norm_type=norm_type,
|
||||||
|
kernel_sizes=qkv_multiscales,
|
||||||
|
residual_connection=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv_out = GLUMBConv(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=in_channels,
|
||||||
|
norm_type="rms_norm",
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = self.attn(x)
|
||||||
|
x = self.conv_out(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class GLUMBConv(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
expand_ratio: float = 4,
|
||||||
|
norm_type: str = None,
|
||||||
|
residual_connection: bool = True,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
hidden_channels = int(expand_ratio * in_channels)
|
||||||
|
self.norm_type = norm_type
|
||||||
|
self.residual_connection = residual_connection
|
||||||
|
|
||||||
|
self.nonlinearity = nn.SiLU()
|
||||||
|
self.conv_inverted = ops.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
|
||||||
|
self.conv_depth = ops.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
|
||||||
|
self.conv_point = ops.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
|
||||||
|
|
||||||
|
self.norm = None
|
||||||
|
if norm_type == "rms_norm":
|
||||||
|
self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.residual_connection:
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.conv_inverted(hidden_states)
|
||||||
|
hidden_states = self.nonlinearity(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = self.conv_depth(hidden_states)
|
||||||
|
hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
|
||||||
|
hidden_states = hidden_states * self.nonlinearity(gate)
|
||||||
|
|
||||||
|
hidden_states = self.conv_point(hidden_states)
|
||||||
|
|
||||||
|
if self.norm_type == "rms_norm":
|
||||||
|
# move channels to the last dimension so RMSNorm is applied across the channel dimension
|
||||||
|
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||||
|
|
||||||
|
if self.residual_connection:
|
||||||
|
hidden_states = hidden_states + residual
|
||||||
|
|
||||||
|
return hidden_states
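# Shape sketch (illustrative comment, not part of the original code), assuming
# in_channels=64 and expand_ratio=4: conv_inverted expands 64 -> 512 channels,
# the depthwise conv keeps 512, the chunk splits them into a 256-channel value
# and a 256-channel gate, value * SiLU(gate) keeps 256, and conv_point projects
# back to out_channels. This is the gated "GLU + MBConv" structure the class
# name refers to.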
|
||||||
|
|
||||||
|
|
||||||
|
def get_block(
|
||||||
|
block_type: str,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
attention_head_dim: int,
|
||||||
|
norm_type: str,
|
||||||
|
act_fn: str,
|
||||||
|
qkv_mutliscales: tuple = (),
|
||||||
|
):
|
||||||
|
if block_type == "ResBlock":
|
||||||
|
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
|
||||||
|
elif block_type == "EfficientViTBlock":
|
||||||
|
block = EfficientViTBlock(
|
||||||
|
in_channels,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
norm_type=norm_type,
|
||||||
|
qkv_multiscales=qkv_mutliscales
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Block with {block_type=} is not supported.")
|
||||||
|
|
||||||
|
return block
|
||||||
|
|
||||||
|
|
||||||
|
class DCDownBlock2d(nn.Module):
|
||||||
|
def __init__(self, in_channels: int, out_channels: int, downsample: bool = False, shortcut: bool = True) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.downsample = downsample
|
||||||
|
self.factor = 2
|
||||||
|
self.stride = 1 if downsample else 2
|
||||||
|
self.group_size = in_channels * self.factor**2 // out_channels
|
||||||
|
self.shortcut = shortcut
|
||||||
|
|
||||||
|
out_ratio = self.factor**2
|
||||||
|
if downsample:
|
||||||
|
assert out_channels % out_ratio == 0
|
||||||
|
out_channels = out_channels // out_ratio
|
||||||
|
|
||||||
|
self.conv = ops.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=self.stride,
|
||||||
|
padding=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = self.conv(hidden_states)
|
||||||
|
if self.downsample:
|
||||||
|
x = F.pixel_unshuffle(x, self.factor)
|
||||||
|
|
||||||
|
if self.shortcut:
|
||||||
|
y = F.pixel_unshuffle(hidden_states, self.factor)
|
||||||
|
y = y.unflatten(1, (-1, self.group_size))
|
||||||
|
y = y.mean(dim=2)
|
||||||
|
hidden_states = x + y
|
||||||
|
else:
|
||||||
|
hidden_states = x
|
||||||
|
|
||||||
|
return hidden_states
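# Worked example (illustrative comment, assumed sizes): with in_channels=128,
# out_channels=256 and downsample=True, the conv maps 128 -> 64 channels at full
# resolution and pixel_unshuffle packs each 2x2 patch into 256 channels at half
# resolution; the shortcut pixel_unshuffles the input to 512 channels and
# averages groups of group_size = 128 * 4 // 256 = 2 to also reach 256 channels
# before the two paths are summed.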
|
||||||
|
|
||||||
|
|
||||||
|
class DCUpBlock2d(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
interpolate: bool = False,
|
||||||
|
shortcut: bool = True,
|
||||||
|
interpolation_mode: str = "nearest",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.interpolate = interpolate
|
||||||
|
self.interpolation_mode = interpolation_mode
|
||||||
|
self.shortcut = shortcut
|
||||||
|
self.factor = 2
|
||||||
|
self.repeats = out_channels * self.factor**2 // in_channels
|
||||||
|
|
||||||
|
out_ratio = self.factor**2
|
||||||
|
if not interpolate:
|
||||||
|
out_channels = out_channels * out_ratio
|
||||||
|
|
||||||
|
self.conv = ops.Conv2d(in_channels, out_channels, 3, 1, 1)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.interpolate:
|
||||||
|
x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
|
||||||
|
x = self.conv(x)
|
||||||
|
else:
|
||||||
|
x = self.conv(hidden_states)
|
||||||
|
x = F.pixel_shuffle(x, self.factor)
|
||||||
|
|
||||||
|
if self.shortcut:
|
||||||
|
y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
|
||||||
|
y = F.pixel_shuffle(y, self.factor)
|
||||||
|
hidden_states = x + y
|
||||||
|
else:
|
||||||
|
hidden_states = x
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
latent_channels: int,
|
||||||
|
attention_head_dim: int = 32,
|
||||||
|
block_type: str or tuple = "ResBlock",
|
||||||
|
block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024),
|
||||||
|
layers_per_block: tuple = (2, 2, 2, 2, 2, 2),
|
||||||
|
qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)),
|
||||||
|
downsample_block_type: str = "pixel_unshuffle",
|
||||||
|
out_shortcut: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
num_blocks = len(block_out_channels)
|
||||||
|
|
||||||
|
if isinstance(block_type, str):
|
||||||
|
block_type = (block_type,) * num_blocks
|
||||||
|
|
||||||
|
if layers_per_block[0] > 0:
|
||||||
|
self.conv_in = ops.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.conv_in = DCDownBlock2d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
|
||||||
|
downsample=downsample_block_type == "pixel_unshuffle",
|
||||||
|
shortcut=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
down_blocks = []
|
||||||
|
for i, (out_channel, num_layers) in enumerate(zip(block_out_channels, layers_per_block)):
|
||||||
|
down_block_list = []
|
||||||
|
|
||||||
|
for _ in range(num_layers):
|
||||||
|
block = get_block(
|
||||||
|
block_type[i],
|
||||||
|
out_channel,
|
||||||
|
out_channel,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
norm_type="rms_norm",
|
||||||
|
act_fn="silu",
|
||||||
|
qkv_mutliscales=qkv_multiscales[i],
|
||||||
|
)
|
||||||
|
down_block_list.append(block)
|
||||||
|
|
||||||
|
if i < num_blocks - 1 and num_layers > 0:
|
||||||
|
downsample_block = DCDownBlock2d(
|
||||||
|
in_channels=out_channel,
|
||||||
|
out_channels=block_out_channels[i + 1],
|
||||||
|
downsample=downsample_block_type == "pixel_unshuffle",
|
||||||
|
shortcut=True,
|
||||||
|
)
|
||||||
|
down_block_list.append(downsample_block)
|
||||||
|
|
||||||
|
down_blocks.append(nn.Sequential(*down_block_list))
|
||||||
|
|
||||||
|
self.down_blocks = nn.ModuleList(down_blocks)
|
||||||
|
|
||||||
|
self.conv_out = ops.Conv2d(block_out_channels[-1], latent_channels, 3, 1, 1)
|
||||||
|
|
||||||
|
self.out_shortcut = out_shortcut
|
||||||
|
if out_shortcut:
|
||||||
|
self.out_shortcut_average_group_size = block_out_channels[-1] // latent_channels
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
hidden_states = self.conv_in(hidden_states)
|
||||||
|
for down_block in self.down_blocks:
|
||||||
|
hidden_states = down_block(hidden_states)
|
||||||
|
|
||||||
|
if self.out_shortcut:
|
||||||
|
x = hidden_states.unflatten(1, (-1, self.out_shortcut_average_group_size))
|
||||||
|
x = x.mean(dim=2)
|
||||||
|
hidden_states = self.conv_out(hidden_states) + x
|
||||||
|
else:
|
||||||
|
hidden_states = self.conv_out(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
latent_channels: int,
|
||||||
|
attention_head_dim: int = 32,
|
||||||
|
block_type: str or tuple = "ResBlock",
|
||||||
|
block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024),
|
||||||
|
layers_per_block: tuple = (2, 2, 2, 2, 2, 2),
|
||||||
|
qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)),
|
||||||
|
norm_type: str or tuple = "rms_norm",
|
||||||
|
act_fn: str or tuple = "silu",
|
||||||
|
upsample_block_type: str = "pixel_shuffle",
|
||||||
|
in_shortcut: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
num_blocks = len(block_out_channels)
|
||||||
|
|
||||||
|
if isinstance(block_type, str):
|
||||||
|
block_type = (block_type,) * num_blocks
|
||||||
|
if isinstance(norm_type, str):
|
||||||
|
norm_type = (norm_type,) * num_blocks
|
||||||
|
if isinstance(act_fn, str):
|
||||||
|
act_fn = (act_fn,) * num_blocks
|
||||||
|
|
||||||
|
self.conv_in = ops.Conv2d(latent_channels, block_out_channels[-1], 3, 1, 1)
|
||||||
|
|
||||||
|
self.in_shortcut = in_shortcut
|
||||||
|
if in_shortcut:
|
||||||
|
self.in_shortcut_repeats = block_out_channels[-1] // latent_channels
|
||||||
|
|
||||||
|
up_blocks = []
|
||||||
|
for i, (out_channel, num_layers) in reversed(list(enumerate(zip(block_out_channels, layers_per_block)))):
|
||||||
|
up_block_list = []
|
||||||
|
|
||||||
|
if i < num_blocks - 1 and num_layers > 0:
|
||||||
|
upsample_block = DCUpBlock2d(
|
||||||
|
block_out_channels[i + 1],
|
||||||
|
out_channel,
|
||||||
|
interpolate=upsample_block_type == "interpolate",
|
||||||
|
shortcut=True,
|
||||||
|
)
|
||||||
|
up_block_list.append(upsample_block)
|
||||||
|
|
||||||
|
for _ in range(num_layers):
|
||||||
|
block = get_block(
|
||||||
|
block_type[i],
|
||||||
|
out_channel,
|
||||||
|
out_channel,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
norm_type=norm_type[i],
|
||||||
|
act_fn=act_fn[i],
|
||||||
|
qkv_mutliscales=qkv_multiscales[i],
|
||||||
|
)
|
||||||
|
up_block_list.append(block)
|
||||||
|
|
||||||
|
up_blocks.insert(0, nn.Sequential(*up_block_list))
|
||||||
|
|
||||||
|
self.up_blocks = nn.ModuleList(up_blocks)
|
||||||
|
|
||||||
|
channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
|
||||||
|
|
||||||
|
self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
|
||||||
|
self.conv_act = nn.ReLU()
|
||||||
|
self.conv_out = None
|
||||||
|
|
||||||
|
if layers_per_block[0] > 0:
|
||||||
|
self.conv_out = ops.Conv2d(channels, in_channels, 3, 1, 1)
|
||||||
|
else:
|
||||||
|
self.conv_out = DCUpBlock2d(
|
||||||
|
channels, in_channels, interpolate=upsample_block_type == "interpolate", shortcut=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.in_shortcut:
|
||||||
|
x = hidden_states.repeat_interleave(
|
||||||
|
self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
|
||||||
|
)
|
||||||
|
hidden_states = self.conv_in(hidden_states) + x
|
||||||
|
else:
|
||||||
|
hidden_states = self.conv_in(hidden_states)
|
||||||
|
|
||||||
|
for up_block in reversed(self.up_blocks):
|
||||||
|
hidden_states = up_block(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||||
|
hidden_states = self.conv_act(hidden_states)
|
||||||
|
hidden_states = self.conv_out(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class AutoencoderDC(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int = 2,
|
||||||
|
latent_channels: int = 8,
|
||||||
|
attention_head_dim: int = 32,
|
||||||
|
encoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"],
|
||||||
|
decoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"],
|
||||||
|
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024),
|
||||||
|
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024),
|
||||||
|
encoder_layers_per_block: Tuple[int] = (2, 2, 3, 3),
|
||||||
|
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3),
|
||||||
|
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)),
|
||||||
|
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)),
|
||||||
|
upsample_block_type: str = "interpolate",
|
||||||
|
downsample_block_type: str = "Conv",
|
||||||
|
decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
|
||||||
|
decoder_act_fns: Union[str, Tuple[str]] = "silu",
|
||||||
|
scaling_factor: float = 0.41407,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.encoder = Encoder(
|
||||||
|
in_channels=in_channels,
|
||||||
|
latent_channels=latent_channels,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
block_type=encoder_block_types,
|
||||||
|
block_out_channels=encoder_block_out_channels,
|
||||||
|
layers_per_block=encoder_layers_per_block,
|
||||||
|
qkv_multiscales=encoder_qkv_multiscales,
|
||||||
|
downsample_block_type=downsample_block_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.decoder = Decoder(
|
||||||
|
in_channels=in_channels,
|
||||||
|
latent_channels=latent_channels,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
block_type=decoder_block_types,
|
||||||
|
block_out_channels=decoder_block_out_channels,
|
||||||
|
layers_per_block=decoder_layers_per_block,
|
||||||
|
qkv_multiscales=decoder_qkv_multiscales,
|
||||||
|
norm_type=decoder_norm_types,
|
||||||
|
act_fn=decoder_act_fns,
|
||||||
|
upsample_block_type=upsample_block_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
|
self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
|
||||||
|
|
||||||
|
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Internal encoding function."""
|
||||||
|
encoded = self.encoder(x)
|
||||||
|
return encoded * self.scaling_factor
|
||||||
|
|
||||||
|
def decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||||
|
# Scale the latents back
|
||||||
|
z = z / self.scaling_factor
|
||||||
|
decoded = self.decoder(z)
|
||||||
|
return decoded
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
z = self.encode(x)
|
||||||
|
return self.decode(z)
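# Rough usage sketch (illustrative, not part of the original commit; the default
# constructor arguments above are assumed):
#   dcae = AutoencoderDC()                 # 2 input channels, 8 latent channels
#   mel = torch.randn(1, 2, 128, 1024)     # (batch, channels, n_mels, frames)
#   z = dcae.encode(mel)                   # -> (1, 8, 16, 128), downsampled 8x
#   recon = dcae.decode(z)                 # -> (1, 2, 128, 1024)
# spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1) = 8 here.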
|
||||||
|
|
104 comfy/ldm/ace/vae/music_dcae_pipeline.py Normal file
@ -0,0 +1,104 @@
|
|||||||
|
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_dcae_pipeline.py
|
||||||
|
import torch
|
||||||
|
from .autoencoder_dc import AutoencoderDC
|
||||||
|
import torchaudio
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
from .music_vocoder import ADaMoSHiFiGANV1
|
||||||
|
|
||||||
|
|
||||||
|
class MusicDCAE(torch.nn.Module):
|
||||||
|
def __init__(self, source_sample_rate=None, dcae_config={}, vocoder_config={}):
|
||||||
|
super(MusicDCAE, self).__init__()
|
||||||
|
|
||||||
|
self.dcae = AutoencoderDC(**dcae_config)
|
||||||
|
self.vocoder = ADaMoSHiFiGANV1(**vocoder_config)
|
||||||
|
|
||||||
|
if source_sample_rate is None:
|
||||||
|
self.source_sample_rate = 48000
|
||||||
|
else:
|
||||||
|
self.source_sample_rate = source_sample_rate
|
||||||
|
|
||||||
|
# self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
|
||||||
|
|
||||||
|
self.transform = transforms.Compose([
|
||||||
|
transforms.Normalize(0.5, 0.5),
|
||||||
|
])
|
||||||
|
self.min_mel_value = -11.0
|
||||||
|
self.max_mel_value = 3.0
|
||||||
|
self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
|
||||||
|
self.mel_chunk_size = 1024
|
||||||
|
self.time_dimention_multiple = 8
|
||||||
|
self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
|
||||||
|
self.scale_factor = 0.1786
|
||||||
|
self.shift_factor = -1.9091
|
||||||
|
|
||||||
|
def load_audio(self, audio_path):
|
||||||
|
audio, sr = torchaudio.load(audio_path)
|
||||||
|
return audio, sr
|
||||||
|
|
||||||
|
def forward_mel(self, audios):
|
||||||
|
mels = []
|
||||||
|
for i in range(len(audios)):
|
||||||
|
image = self.vocoder.mel_transform(audios[i])
|
||||||
|
mels.append(image)
|
||||||
|
mels = torch.stack(mels)
|
||||||
|
return mels
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def encode(self, audios, audio_lengths=None, sr=None):
|
||||||
|
if audio_lengths is None:
|
||||||
|
audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
|
||||||
|
audio_lengths = audio_lengths.to(audios.device)
|
||||||
|
|
||||||
|
if sr is None:
|
||||||
|
sr = self.source_sample_rate
|
||||||
|
|
||||||
|
if sr != 44100:
|
||||||
|
audios = torchaudio.functional.resample(audios, sr, 44100)
|
||||||
|
|
||||||
|
max_audio_len = audios.shape[-1]
|
||||||
|
if max_audio_len % (8 * 512) != 0:
|
||||||
|
audios = torch.nn.functional.pad(audios, (0, 8 * 512 - max_audio_len % (8 * 512)))
|
||||||
|
|
||||||
|
mels = self.forward_mel(audios)
|
||||||
|
mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
|
||||||
|
mels = self.transform(mels)
|
||||||
|
latents = []
|
||||||
|
for mel in mels:
|
||||||
|
latent = self.dcae.encoder(mel.unsqueeze(0))
|
||||||
|
latents.append(latent)
|
||||||
|
latents = torch.cat(latents, dim=0)
|
||||||
|
# latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
|
||||||
|
latents = (latents - self.shift_factor) * self.scale_factor
|
||||||
|
return latents
|
||||||
|
# return latents, latent_lengths
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def decode(self, latents, audio_lengths=None, sr=None):
|
||||||
|
latents = latents / self.scale_factor + self.shift_factor
|
||||||
|
|
||||||
|
pred_wavs = []
|
||||||
|
|
||||||
|
for latent in latents:
|
||||||
|
mels = self.dcae.decoder(latent.unsqueeze(0))
|
||||||
|
mels = mels * 0.5 + 0.5
|
||||||
|
mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
|
||||||
|
wav = self.vocoder.decode(mels[0]).squeeze(1)
|
||||||
|
|
||||||
|
if sr is not None:
|
||||||
|
# resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
|
||||||
|
wav = torchaudio.functional.resample(wav, 44100, sr)
|
||||||
|
# wav = resampler(wav)
|
||||||
|
else:
|
||||||
|
sr = 44100
|
||||||
|
pred_wavs.append(wav)
|
||||||
|
|
||||||
|
if audio_lengths is not None:
|
||||||
|
pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
|
||||||
|
return torch.stack(pred_wavs)
|
||||||
|
# return sr, pred_wavs
|
||||||
|
|
||||||
|
    def forward(self, audios, audio_lengths=None, sr=None):
        # encode() above returns only the latents and decode() returns only the
        # stacked waveforms, so unpack them accordingly.
        latents = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
        pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
        return pred_wavs, latents
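# Rough usage sketch (illustrative, not part of the original commit; the default
# 48 kHz source rate is assumed):
#   dcae = MusicDCAE()
#   audio = torch.randn(1, 2, 48000 * 4)          # 4 seconds of stereo audio
#   latents = dcae.encode(audio)                  # -> (1, 8, 16, T) latent grid
#   waveform = dcae.decode(latents, sr=48000)     # decoded stereo audio at 48 kHz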
|
108 comfy/ldm/ace/vae/music_log_mel.py Executable file
@ -0,0 +1,108 @@
|
|||||||
|
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_log_mel.py
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
from torchaudio.transforms import MelScale
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
|
class LinearSpectrogram(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
n_fft=2048,
|
||||||
|
win_length=2048,
|
||||||
|
hop_length=512,
|
||||||
|
center=False,
|
||||||
|
mode="pow2_sqrt",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.n_fft = n_fft
|
||||||
|
self.win_length = win_length
|
||||||
|
self.hop_length = hop_length
|
||||||
|
self.center = center
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
|
self.register_buffer("window", torch.hann_window(win_length))
|
||||||
|
|
||||||
|
def forward(self, y: Tensor) -> Tensor:
|
||||||
|
if y.ndim == 3:
|
||||||
|
y = y.squeeze(1)
|
||||||
|
|
||||||
|
y = torch.nn.functional.pad(
|
||||||
|
y.unsqueeze(1),
|
||||||
|
(
|
||||||
|
(self.win_length - self.hop_length) // 2,
|
||||||
|
(self.win_length - self.hop_length + 1) // 2,
|
||||||
|
),
|
||||||
|
mode="reflect",
|
||||||
|
).squeeze(1)
|
||||||
|
dtype = y.dtype
|
||||||
|
spec = torch.stft(
|
||||||
|
y.float(),
|
||||||
|
self.n_fft,
|
||||||
|
hop_length=self.hop_length,
|
||||||
|
win_length=self.win_length,
|
||||||
|
window=comfy.model_management.cast_to(self.window, dtype=torch.float32, device=y.device),
|
||||||
|
center=self.center,
|
||||||
|
pad_mode="reflect",
|
||||||
|
normalized=False,
|
||||||
|
onesided=True,
|
||||||
|
return_complex=True,
|
||||||
|
)
|
||||||
|
spec = torch.view_as_real(spec)
|
||||||
|
|
||||||
|
if self.mode == "pow2_sqrt":
|
||||||
|
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||||
|
spec = spec.to(dtype)
|
||||||
|
return spec
|
||||||
|
|
||||||
|
|
||||||
|
class LogMelSpectrogram(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sample_rate=44100,
|
||||||
|
n_fft=2048,
|
||||||
|
win_length=2048,
|
||||||
|
hop_length=512,
|
||||||
|
n_mels=128,
|
||||||
|
center=False,
|
||||||
|
f_min=0.0,
|
||||||
|
f_max=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.n_fft = n_fft
|
||||||
|
self.win_length = win_length
|
||||||
|
self.hop_length = hop_length
|
||||||
|
self.center = center
|
||||||
|
self.n_mels = n_mels
|
||||||
|
self.f_min = f_min
|
||||||
|
self.f_max = f_max or sample_rate // 2
|
||||||
|
|
||||||
|
self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
|
||||||
|
self.mel_scale = MelScale(
|
||||||
|
self.n_mels,
|
||||||
|
self.sample_rate,
|
||||||
|
self.f_min,
|
||||||
|
self.f_max,
|
||||||
|
self.n_fft // 2 + 1,
|
||||||
|
"slaney",
|
||||||
|
"slaney",
|
||||||
|
)
|
||||||
|
|
||||||
|
def compress(self, x: Tensor) -> Tensor:
|
||||||
|
return torch.log(torch.clamp(x, min=1e-5))
|
||||||
|
|
||||||
|
def decompress(self, x: Tensor) -> Tensor:
|
||||||
|
return torch.exp(x)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
|
||||||
|
linear = self.spectrogram(x)
|
||||||
|
x = self.mel_scale(linear)
|
||||||
|
x = self.compress(x)
|
||||||
|
# print(x.shape)
|
||||||
|
if return_linear:
|
||||||
|
return x, self.compress(linear)
|
||||||
|
|
||||||
|
return x
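# Example (illustrative comment, not from the original source): with the default
# 44.1 kHz settings this produces 128 log-mel bins and one frame per 512 samples.
#   mel_fn = LogMelSpectrogram()
#   wav = torch.randn(1, 44100)        # 1 second of mono audio
#   mel = mel_fn(wav)                  # -> (1, 128, 86), since 44100 // 512 = 86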
|
542 comfy/ldm/ace/vae/music_vocoder.py Executable file
@ -0,0 +1,542 @@
|
|||||||
|
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_vocoder.py
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
from math import prod
|
||||||
|
from typing import Callable, Tuple, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn.utils import weight_norm
|
||||||
|
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
|
||||||
|
# from diffusers.models.modeling_utils import ModelMixin
|
||||||
|
# from diffusers.loaders import FromOriginalModelMixin
|
||||||
|
# from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||||
|
|
||||||
|
from .music_log_mel import LogMelSpectrogram
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.ops
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
|
||||||
|
def drop_path(
|
||||||
|
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
|
||||||
|
):
|
||||||
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
|
|
||||||
|
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
||||||
|
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||||
|
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||||
|
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||||
|
'survival rate' as the argument.
|
||||||
|
|
||||||
|
""" # noqa: E501
|
||||||
|
|
||||||
|
if drop_prob == 0.0 or not training:
|
||||||
|
return x
|
||||||
|
keep_prob = 1 - drop_prob
|
||||||
|
shape = (x.shape[0],) + (1,) * (
|
||||||
|
x.ndim - 1
|
||||||
|
) # work with diff dim tensors, not just 2D ConvNets
|
||||||
|
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
||||||
|
if keep_prob > 0.0 and scale_by_keep:
|
||||||
|
random_tensor.div_(keep_prob)
|
||||||
|
return x * random_tensor
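# Behaviour sketch (added comment, not part of the original code): with
# drop_prob=0.25 and training=True each sample in the batch independently keeps
# its residual branch with probability 0.75, and kept samples are rescaled by
# 1 / 0.75 so the expected output is unchanged; with training=False the input
# passes through untouched.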
|
||||||
|
|
||||||
|
|
||||||
|
class DropPath(nn.Module):
|
||||||
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
|
||||||
|
|
||||||
|
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
|
||||||
|
super(DropPath, self).__init__()
|
||||||
|
self.drop_prob = drop_prob
|
||||||
|
self.scale_by_keep = scale_by_keep
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"drop_prob={round(self.drop_prob,3):0.3f}"
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.Module):
|
||||||
|
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
||||||
|
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
|
||||||
|
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
|
||||||
|
with shape (batch_size, channels, height, width).
|
||||||
|
""" # noqa: E501
|
||||||
|
|
||||||
|
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
||||||
|
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
||||||
|
self.eps = eps
|
||||||
|
self.data_format = data_format
|
||||||
|
if self.data_format not in ["channels_last", "channels_first"]:
|
||||||
|
raise NotImplementedError
|
||||||
|
self.normalized_shape = (normalized_shape,)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.data_format == "channels_last":
|
||||||
|
return F.layer_norm(
|
||||||
|
x, self.normalized_shape, comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device), comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device), self.eps
|
||||||
|
)
|
||||||
|
elif self.data_format == "channels_first":
|
||||||
|
u = x.mean(1, keepdim=True)
|
||||||
|
s = (x - u).pow(2).mean(1, keepdim=True)
|
||||||
|
x = (x - u) / torch.sqrt(s + self.eps)
|
||||||
|
x = comfy.model_management.cast_to(self.weight[:, None], dtype=x.dtype, device=x.device) * x + comfy.model_management.cast_to(self.bias[:, None], dtype=x.dtype, device=x.device)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ConvNeXtBlock(nn.Module):
|
||||||
|
r"""ConvNeXt Block. There are two equivalent implementations:
|
||||||
|
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
||||||
|
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
||||||
|
We use (2) as we find it slightly faster in PyTorch
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dim (int): Number of input channels.
|
||||||
|
drop_path (float): Stochastic depth rate. Default: 0.0
|
||||||
|
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
||||||
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
|
||||||
|
kernel_size (int): Kernel size for depthwise conv. Default: 7.
|
||||||
|
dilation (int): Dilation for depthwise conv. Default: 1.
|
||||||
|
""" # noqa: E501
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
drop_path: float = 0.0,
|
||||||
|
layer_scale_init_value: float = 1e-6,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
kernel_size: int = 7,
|
||||||
|
dilation: int = 1,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dwconv = ops.Conv1d(
|
||||||
|
dim,
|
||||||
|
dim,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
padding=int(dilation * (kernel_size - 1) / 2),
|
||||||
|
groups=dim,
|
||||||
|
) # depthwise conv
|
||||||
|
self.norm = LayerNorm(dim, eps=1e-6)
|
||||||
|
self.pwconv1 = ops.Linear(
|
||||||
|
dim, int(mlp_ratio * dim)
|
||||||
|
) # pointwise/1x1 convs, implemented with linear layers
|
||||||
|
self.act = nn.GELU()
|
||||||
|
self.pwconv2 = ops.Linear(int(mlp_ratio * dim), dim)
|
||||||
|
self.gamma = (
|
||||||
|
nn.Parameter(torch.empty((dim)), requires_grad=False)
|
||||||
|
if layer_scale_init_value > 0
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
self.drop_path = DropPath(
|
||||||
|
drop_path) if drop_path > 0.0 else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x, apply_residual: bool = True):
|
||||||
|
input = x
|
||||||
|
|
||||||
|
x = self.dwconv(x)
|
||||||
|
x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.pwconv1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.pwconv2(x)
|
||||||
|
|
||||||
|
if self.gamma is not None:
|
||||||
|
x = comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) * x
|
||||||
|
|
||||||
|
x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
|
||||||
|
x = self.drop_path(x)
|
||||||
|
|
||||||
|
if apply_residual:
|
||||||
|
x = input + x
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ParallelConvNeXtBlock(nn.Module):
|
||||||
|
def __init__(self, kernel_sizes: List[int], *args, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
ConvNeXtBlock(kernel_size=kernel_size, *args, **kwargs)
|
||||||
|
for kernel_size in kernel_sizes
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return torch.stack(
|
||||||
|
[block(x, apply_residual=False) for block in self.blocks] + [x],
|
||||||
|
dim=1,
|
||||||
|
).sum(dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
class ConvNeXtEncoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_channels=3,
|
||||||
|
depths=[3, 3, 9, 3],
|
||||||
|
dims=[96, 192, 384, 768],
|
||||||
|
drop_path_rate=0.0,
|
||||||
|
layer_scale_init_value=1e-6,
|
||||||
|
kernel_sizes: Tuple[int] = (7,),
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert len(depths) == len(dims)
|
||||||
|
|
||||||
|
self.channel_layers = nn.ModuleList()
|
||||||
|
stem = nn.Sequential(
|
||||||
|
ops.Conv1d(
|
||||||
|
input_channels,
|
||||||
|
dims[0],
|
||||||
|
kernel_size=7,
|
||||||
|
padding=3,
|
||||||
|
padding_mode="replicate",
|
||||||
|
),
|
||||||
|
LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
|
||||||
|
)
|
||||||
|
self.channel_layers.append(stem)
|
||||||
|
|
||||||
|
for i in range(len(depths) - 1):
|
||||||
|
mid_layer = nn.Sequential(
|
||||||
|
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
|
||||||
|
ops.Conv1d(dims[i], dims[i + 1], kernel_size=1),
|
||||||
|
)
|
||||||
|
self.channel_layers.append(mid_layer)
|
||||||
|
|
||||||
|
block_fn = (
|
||||||
|
partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
|
||||||
|
if len(kernel_sizes) == 1
|
||||||
|
else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.stages = nn.ModuleList()
|
||||||
|
drop_path_rates = [
|
||||||
|
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
||||||
|
]
|
||||||
|
|
||||||
|
cur = 0
|
||||||
|
for i in range(len(depths)):
|
||||||
|
stage = nn.Sequential(
|
||||||
|
*[
|
||||||
|
block_fn(
|
||||||
|
dim=dims[i],
|
||||||
|
drop_path=drop_path_rates[cur + j],
|
||||||
|
layer_scale_init_value=layer_scale_init_value,
|
||||||
|
)
|
||||||
|
for j in range(depths[i])
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.stages.append(stage)
|
||||||
|
cur += depths[i]
|
||||||
|
|
||||||
|
self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
for channel_layer, stage in zip(self.channel_layers, self.stages):
|
||||||
|
x = channel_layer(x)
|
||||||
|
x = stage(x)
|
||||||
|
|
||||||
|
return self.norm(x)
|
||||||
|
|
||||||
|
|
||||||
|
def get_padding(kernel_size, dilation=1):
|
||||||
|
return (kernel_size * dilation - dilation) // 2
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock1(torch.nn.Module):
|
||||||
|
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.convs1 = nn.ModuleList(
|
||||||
|
[
|
||||||
|
weight_norm(
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[0],
|
||||||
|
padding=get_padding(kernel_size, dilation[0]),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
weight_norm(
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[1],
|
||||||
|
padding=get_padding(kernel_size, dilation[1]),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
weight_norm(
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[2],
|
||||||
|
padding=get_padding(kernel_size, dilation[2]),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.convs2 = nn.ModuleList(
|
||||||
|
[
|
||||||
|
weight_norm(
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
weight_norm(
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
weight_norm(
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for c1, c2 in zip(self.convs1, self.convs2):
|
||||||
|
xt = F.silu(x)
|
||||||
|
xt = c1(xt)
|
||||||
|
xt = F.silu(xt)
|
||||||
|
xt = c2(xt)
|
||||||
|
x = xt + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
for conv in self.convs1:
|
||||||
|
remove_weight_norm(conv)
|
||||||
|
for conv in self.convs2:
|
||||||
|
remove_weight_norm(conv)
|
||||||
|
|
||||||
|
|
||||||
|
class HiFiGANGenerator(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
hop_length: int = 512,
|
||||||
|
upsample_rates: Tuple[int] = (8, 8, 2, 2, 2),
|
||||||
|
upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2),
|
||||||
|
resblock_kernel_sizes: Tuple[int] = (3, 7, 11),
|
||||||
|
resblock_dilation_sizes: Tuple[Tuple[int]] = (
|
||||||
|
(1, 3, 5), (1, 3, 5), (1, 3, 5)),
|
||||||
|
num_mels: int = 128,
|
||||||
|
upsample_initial_channel: int = 512,
|
||||||
|
use_template: bool = True,
|
||||||
|
pre_conv_kernel_size: int = 7,
|
||||||
|
post_conv_kernel_size: int = 7,
|
||||||
|
post_activation: Callable = partial(nn.SiLU, inplace=True),
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
prod(upsample_rates) == hop_length
|
||||||
|
), f"hop_length must be {prod(upsample_rates)}"
|
||||||
|
|
||||||
|
self.conv_pre = weight_norm(
|
||||||
|
ops.Conv1d(
|
||||||
|
num_mels,
|
||||||
|
upsample_initial_channel,
|
||||||
|
pre_conv_kernel_size,
|
||||||
|
1,
|
||||||
|
padding=get_padding(pre_conv_kernel_size),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_upsamples = len(upsample_rates)
|
||||||
|
self.num_kernels = len(resblock_kernel_sizes)
|
||||||
|
|
||||||
|
self.noise_convs = nn.ModuleList()
|
||||||
|
self.use_template = use_template
|
||||||
|
self.ups = nn.ModuleList()
|
||||||
|
|
||||||
|
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||||
|
c_cur = upsample_initial_channel // (2 ** (i + 1))
|
||||||
|
self.ups.append(
|
||||||
|
weight_norm(
|
||||||
|
ops.ConvTranspose1d(
|
||||||
|
upsample_initial_channel // (2**i),
|
||||||
|
upsample_initial_channel // (2 ** (i + 1)),
|
||||||
|
k,
|
||||||
|
u,
|
||||||
|
padding=(k - u) // 2,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not use_template:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if i + 1 < len(upsample_rates):
|
||||||
|
stride_f0 = np.prod(upsample_rates[i + 1:])
|
||||||
|
self.noise_convs.append(
|
||||||
|
ops.Conv1d(
|
||||||
|
1,
|
||||||
|
c_cur,
|
||||||
|
kernel_size=stride_f0 * 2,
|
||||||
|
stride=stride_f0,
|
||||||
|
padding=stride_f0 // 2,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.noise_convs.append(ops.Conv1d(1, c_cur, kernel_size=1))
|
||||||
|
|
||||||
|
self.resblocks = nn.ModuleList()
|
||||||
|
for i in range(len(self.ups)):
|
||||||
|
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||||
|
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
|
||||||
|
self.resblocks.append(ResBlock1(ch, k, d))
|
||||||
|
|
||||||
|
self.activation_post = post_activation()
|
||||||
|
self.conv_post = weight_norm(
|
||||||
|
ops.Conv1d(
|
||||||
|
ch,
|
||||||
|
1,
|
||||||
|
post_conv_kernel_size,
|
||||||
|
1,
|
||||||
|
padding=get_padding(post_conv_kernel_size),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, template=None):
|
||||||
|
x = self.conv_pre(x)
|
||||||
|
|
||||||
|
for i in range(self.num_upsamples):
|
||||||
|
x = F.silu(x, inplace=True)
|
||||||
|
x = self.ups[i](x)
|
||||||
|
|
||||||
|
if self.use_template:
|
||||||
|
x = x + self.noise_convs[i](template)
|
||||||
|
|
||||||
|
xs = None
|
||||||
|
|
||||||
|
for j in range(self.num_kernels):
|
||||||
|
if xs is None:
|
||||||
|
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
else:
|
||||||
|
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
|
||||||
|
x = xs / self.num_kernels
|
||||||
|
|
||||||
|
x = self.activation_post(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
x = torch.tanh(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
for up in self.ups:
|
||||||
|
remove_weight_norm(up)
|
||||||
|
for block in self.resblocks:
|
||||||
|
block.remove_weight_norm()
|
||||||
|
remove_weight_norm(self.conv_pre)
|
||||||
|
remove_weight_norm(self.conv_post)
|
||||||
|
|
||||||
|
|
||||||
|
class ADaMoSHiFiGANV1(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_channels: int = 128,
|
||||||
|
depths: List[int] = [3, 3, 9, 3],
|
||||||
|
dims: List[int] = [128, 256, 384, 512],
|
||||||
|
drop_path_rate: float = 0.0,
|
||||||
|
kernel_sizes: Tuple[int] = (7,),
|
||||||
|
upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2),
|
||||||
|
upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4),
|
||||||
|
resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13),
|
||||||
|
resblock_dilation_sizes: Tuple[Tuple[int]] = (
|
||||||
|
(1, 3, 5), (1, 3, 5), (1, 3, 5), (1, 3, 5)),
|
||||||
|
num_mels: int = 512,
|
||||||
|
upsample_initial_channel: int = 1024,
|
||||||
|
use_template: bool = False,
|
||||||
|
pre_conv_kernel_size: int = 13,
|
||||||
|
post_conv_kernel_size: int = 13,
|
||||||
|
sampling_rate: int = 44100,
|
||||||
|
n_fft: int = 2048,
|
||||||
|
win_length: int = 2048,
|
||||||
|
hop_length: int = 512,
|
||||||
|
f_min: int = 40,
|
||||||
|
f_max: int = 16000,
|
||||||
|
n_mels: int = 128,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.backbone = ConvNeXtEncoder(
|
||||||
|
input_channels=input_channels,
|
||||||
|
depths=depths,
|
||||||
|
dims=dims,
|
||||||
|
drop_path_rate=drop_path_rate,
|
||||||
|
kernel_sizes=kernel_sizes,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.head = HiFiGANGenerator(
|
||||||
|
hop_length=hop_length,
|
||||||
|
upsample_rates=upsample_rates,
|
||||||
|
upsample_kernel_sizes=upsample_kernel_sizes,
|
||||||
|
resblock_kernel_sizes=resblock_kernel_sizes,
|
||||||
|
resblock_dilation_sizes=resblock_dilation_sizes,
|
||||||
|
num_mels=num_mels,
|
||||||
|
upsample_initial_channel=upsample_initial_channel,
|
||||||
|
use_template=use_template,
|
||||||
|
pre_conv_kernel_size=pre_conv_kernel_size,
|
||||||
|
post_conv_kernel_size=post_conv_kernel_size,
|
||||||
|
)
|
||||||
|
self.sampling_rate = sampling_rate
|
||||||
|
self.mel_transform = LogMelSpectrogram(
|
||||||
|
sample_rate=sampling_rate,
|
||||||
|
n_fft=n_fft,
|
||||||
|
win_length=win_length,
|
||||||
|
hop_length=hop_length,
|
||||||
|
f_min=f_min,
|
||||||
|
f_max=f_max,
|
||||||
|
n_mels=n_mels,
|
||||||
|
)
|
||||||
|
self.eval()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def decode(self, mel):
|
||||||
|
y = self.backbone(mel)
|
||||||
|
y = self.head(y)
|
||||||
|
return y
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def encode(self, x):
|
||||||
|
return self.mel_transform(x)
|
||||||
|
|
||||||
|
def forward(self, mel):
|
||||||
|
y = self.backbone(mel)
|
||||||
|
y = self.head(y)
|
||||||
|
return y
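# Usage sketch (illustrative, defaults above assumed): the ConvNeXt backbone maps
# a (B, 128, T) log-mel to (B, 512, T) features and the HiFi-GAN head upsamples
# by prod(upsample_rates) = 4*4*2*2*2*2*2 = 512 = hop_length samples per frame:
#   mel = torch.randn(1, 128, 86)
#   wav = ADaMoSHiFiGANV1().decode(mel)    # -> (1, 1, 86 * 512) waveform at 44.1 kHz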
|
@ -39,6 +39,7 @@ import comfy.ldm.wan.model
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.ace.model

import comfy.model_management
import comfy.patcher_extension
@ -1121,3 +1122,21 @@ class Chroma(Flux):
        if guidance is not None:
            out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
        return out

class ACEStep(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.model.ACEStepTransformer2DModel)

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
        noise = kwargs.get("noise", None)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

        conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
        if cross_attn is not None:
            out['lyric_token_idx'] = comfy.conds.CONDRegular(conditioning_lyrics)
        out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
        return out
@ -226,6 +226,31 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
            dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
        return dit_config

    if '{}genre_embedder.weight'.format(key_prefix) in state_dict_keys: #ACE-Step model
        dit_config = {}
        dit_config["audio_model"] = "ace"
        dit_config["attention_head_dim"] = 128
        dit_config["in_channels"] = 8
        dit_config["inner_dim"] = 2560
        dit_config["max_height"] = 16
        dit_config["max_position"] = 32768
        dit_config["max_width"] = 32768
        dit_config["mlp_ratio"] = 2.5
        dit_config["num_attention_heads"] = 20
        dit_config["num_layers"] = 24
        dit_config["out_channels"] = 8
        dit_config["patch_size"] = [16, 1]
        dit_config["rope_theta"] = 1000000.0
        dit_config["speaker_embedding_dim"] = 512
        dit_config["text_embedding_dim"] = 768

        dit_config["ssl_encoder_depths"] = [8, 8]
        dit_config["ssl_latent_dims"] = [1024, 768]
        dit_config["ssl_names"] = ["mert", "m-hubert"]
        dit_config["lyric_encoder_vocab_size"] = 6693
        dit_config["lyric_hidden_size"] = 1024
        return dit_config

    if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
        patch_size = 2
        dit_config = {}
25 comfy/sd.py
@ -15,6 +15,7 @@ import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import yaml
import math

@ -42,6 +43,7 @@ import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.text_encoders.ace

import comfy.model_patcher
import comfy.lora
@ -437,6 +439,19 @@ class VAE:
            ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
            self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
            self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
        elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
            self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
            self.memory_used_encode = lambda shape, dtype: (shape[2] * 300) * model_management.dtype_size(dtype)
            self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 72000) * model_management.dtype_size(dtype)
            self.latent_channels = 8
            self.output_channels = 2
            # self.upscale_ratio = 2048
            # self.downscale_ratio = 2048
            self.latent_dim = 2
            self.process_output = lambda audio: audio
            self.process_input = lambda audio: audio
            self.working_dtypes = [torch.bfloat16, torch.float32]
            self.disable_offload = True
        else:
            logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
            self.first_stage_model = None
@ -715,6 +730,7 @@ class CLIPType(Enum):
    WAN = 13
    HIDREAM = 14
    CHROMA = 15
    ACE = 16


def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@ -840,8 +856,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
            clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
            clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
        elif te_model == TEModel.T5_BASE:
            if clip_type == CLIPType.ACE or "spiece_model" in clip_data[0]:
                clip_target.clip = comfy.text_encoders.ace.AceT5Model
                clip_target.tokenizer = comfy.text_encoders.ace.AceT5Tokenizer
                tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
            else:
                clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
                clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
        elif te_model == TEModel.GEMMA_2_2B:
            clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
@ -17,6 +17,7 @@ import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.ace

from . import supported_models_base
from . import latent_formats
@ -1100,6 +1101,34 @@ class Chroma(supported_models_base.BASE):
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))

class ACEStep(supported_models_base.BASE):
    unet_config = {
        "audio_model": "ace",
    }

    unet_extra_config = {
    }

    sampling_settings = {
        "shift": 3.0,
    }

    latent_format = comfy.latent_formats.ACEAudio

    memory_usage_factor = 0.5

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.ACEStep(self, device=device)
        return out

    def clip_target(self, state_dict={}):
        return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)

models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]

models += [SVD_img2vid]
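(With ACEStep appended to the models list, a checkpoint whose detected UNet config reports "audio_model": "ace" is matched to this class, which in turn supplies the ACEAudio latent format, the shift-3.0 sampling settings and the ACE tokenizer/text-encoder target declared above. That is how the supported_models registry is generally consulted; the matching code itself is outside this diff.)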
145 comfy/text_encoders/ace.py Normal file
@ -0,0 +1,145 @@
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.t5
import os
import re
import torch
import logging

from tokenizers import Tokenizer
from .ace_text_cleaners import multilingual_cleaners

SUPPORT_LANGUAGES = {
    "en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,
    "pt": 286, "pl": 294, "tr": 295, "ru": 267, "cs": 293,
    "nl": 297, "ar": 5022, "zh": 5023, "ja": 5412, "hu": 5753,
    "ko": 6152, "hi": 6680
}

structure_pattern = re.compile(r"\[.*?\]")

DEFAULT_VOCAB_FILE = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")


class VoiceBpeTokenizer:
    def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
        self.tokenizer = None
        if vocab_file is not None:
            self.tokenizer = Tokenizer.from_file(vocab_file)

    def preprocess_text(self, txt, lang):
        txt = multilingual_cleaners(txt, lang)
        return txt

    def encode(self, txt, lang='en'):
        # lang = lang.split("-")[0]  # remove the region
        # self.check_input_length(txt, lang)
        txt = self.preprocess_text(txt, lang)
        lang = "zh-cn" if lang == "zh" else lang
        txt = f"[{lang}]{txt}"
        txt = txt.replace(" ", "[SPACE]")
        return self.tokenizer.encode(txt).ids

    def get_lang(self, line):
        if line.startswith("[") and line[3:4] == ']':
            lang = line[1:3].lower()
            if lang in SUPPORT_LANGUAGES:
                return lang, line[4:]
        return "en", line

    def __call__(self, string):
        lines = string.split("\n")
        lyric_token_idx = [261]
        for line in lines:
            line = line.strip()
            if not line:
                lyric_token_idx += [2]
                continue

            lang, line = self.get_lang(line)

            if lang not in SUPPORT_LANGUAGES:
                lang = "en"
            if "zh" in lang:
                lang = "zh"
            if "spa" in lang:
                lang = "es"

            try:
                if structure_pattern.match(line):
                    token_idx = self.encode(line, "en")
                else:
                    token_idx = self.encode(line, lang)
                lyric_token_idx = lyric_token_idx + token_idx + [2]
            except Exception as e:
                logging.warning("tokenize error {} for line {} major_language {}".format(e, line, lang))
        return {"input_ids": lyric_token_idx}

    @staticmethod
    def from_pretrained(path, **kwargs):
        return VoiceBpeTokenizer(path, **kwargs)

    def get_vocab(self):
        return {}


class UMT5BaseModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "umt5_config_base.json")
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=False, model_options=model_options)


class UMT5BaseTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer = tokenizer_data.get("spiece_model", None)
        super().__init__(tokenizer, pad_with_end=False, embedding_size=768, embedding_key='umt5base', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=0, tokenizer_data=tokenizer_data)

    def state_dict(self):
        return {"spiece_model": self.tokenizer.serialize_model()}


class LyricsTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
        super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='lyrics', tokenizer_class=VoiceBpeTokenizer, has_start_token=True, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=2, has_end_token=False, tokenizer_data=tokenizer_data)


class AceT5Tokenizer:
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.voicebpe = LyricsTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.umt5base = UMT5BaseTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = {}
        out["lyrics"] = self.voicebpe.tokenize_with_weights(kwargs.get("lyrics", ""), return_word_ids, **kwargs)
        out["umt5base"] = self.umt5base.tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):
        return self.umt5base.untokenize(token_weight_pair)

    def state_dict(self):
        return self.umt5base.state_dict()


class AceT5Model(torch.nn.Module):
    def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
        super().__init__()
        self.umt5base = UMT5BaseModel(device=device, dtype=dtype, model_options=model_options)
        self.dtypes = set()
        if dtype is not None:
            self.dtypes.add(dtype)

    def set_clip_options(self, options):
        self.umt5base.set_clip_options(options)

    def reset_clip_options(self):
        self.umt5base.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        token_weight_pairs_umt5base = token_weight_pairs["umt5base"]
        token_weight_pairs_lyrics = token_weight_pairs["lyrics"]

        t5_out, t5_pooled = self.umt5base.encode_token_weights(token_weight_pairs_umt5base)

        lyrics_embeds = torch.tensor(list(map(lambda a: a[0], token_weight_pairs_lyrics[0]))).unsqueeze(0)
        return t5_out, None, {"conditioning_lyrics": lyrics_embeds}

    def load_sd(self, sd):
        return self.umt5base.load_sd(sd)
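For orientation, a rough sketch of how the classes above fit together (direct calls for illustration only; in ComfyUI they are driven through the CLIP wrapper, and the tag/lyric strings plus the spiece_model variable are placeholders):

tokenizer = AceT5Tokenizer(tokenizer_data={"spiece_model": spiece_model})  # spiece_model comes from the text encoder checkpoint
model = AceT5Model()
tokens = tokenizer.tokenize_with_weights("pop, female vocals, upbeat", lyrics="[en]la la la")
# tokens == {"lyrics": <VoiceBpeTokenizer ids for the lyrics>, "umt5base": <sentencepiece tokens for the tags>}
cond, pooled, extra = model.encode_token_weights(tokens)
# cond: UMT5-base hidden states for the tags, pooled is None,
# extra["conditioning_lyrics"]: the lyric token ids packed into a tensor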
15535 comfy/text_encoders/ace_lyrics_tokenizer/vocab.json Normal file
File diff suppressed because it is too large
270 comfy/text_encoders/ace_text_cleaners.py Normal file
@ -0,0 +1,270 @@
# basic text cleaners for the ACE step model
# I didn't copy the ones from the reference code because I didn't want to deal with the dependencies
# TODO: more languages than english?

import re

def number_to_text(num, ordinal=False):
    """
    Convert a number (int or float) to its text representation.

    Args:
        num: The number to convert

    Returns:
        str: Text representation of the number
    """

    if not isinstance(num, (int, float)):
        return "Input must be a number"

    # Handle special case of zero
    if num == 0:
        return "zero"

    # Handle negative numbers
    negative = num < 0
    num = abs(num)

    # Handle floats
    if isinstance(num, float):
        # Split into integer and decimal parts
        int_part = int(num)

        # Convert both parts
        int_text = _int_to_text(int_part)

        # Handle decimal part (convert to string and remove '0.')
        decimal_str = str(num).split('.')[1]
        decimal_text = " point " + " ".join(_digit_to_text(int(digit)) for digit in decimal_str)

        result = int_text + decimal_text
    else:
        # Handle integers
        result = _int_to_text(num)

    # Add 'negative' prefix for negative numbers
    if negative:
        result = "negative " + result

    return result


def _int_to_text(num):
    """Helper function to convert an integer to text"""

    ones = ["", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine",
            "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", "sixteen",
            "seventeen", "eighteen", "nineteen"]

    tens = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"]

    if num < 20:
        return ones[num]

    if num < 100:
        return tens[num // 10] + (" " + ones[num % 10] if num % 10 != 0 else "")

    if num < 1000:
        return ones[num // 100] + " hundred" + (" " + _int_to_text(num % 100) if num % 100 != 0 else "")

    if num < 1000000:
        return _int_to_text(num // 1000) + " thousand" + (" " + _int_to_text(num % 1000) if num % 1000 != 0 else "")

    if num < 1000000000:
        return _int_to_text(num // 1000000) + " million" + (" " + _int_to_text(num % 1000000) if num % 1000000 != 0 else "")

    return _int_to_text(num // 1000000000) + " billion" + (" " + _int_to_text(num % 1000000000) if num % 1000000000 != 0 else "")


def _digit_to_text(digit):
    """Convert a single digit to text"""
    digits = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
    return digits[digit]


_whitespace_re = re.compile(r"\s+")


# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = {
    "en": [
        (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
        for x in [
            ("mrs", "misess"),
            ("mr", "mister"),
            ("dr", "doctor"),
            ("st", "saint"),
            ("co", "company"),
            ("jr", "junior"),
            ("maj", "major"),
            ("gen", "general"),
            ("drs", "doctors"),
            ("rev", "reverend"),
            ("lt", "lieutenant"),
            ("hon", "honorable"),
            ("sgt", "sergeant"),
            ("capt", "captain"),
            ("esq", "esquire"),
            ("ltd", "limited"),
            ("col", "colonel"),
            ("ft", "fort"),
        ]
    ],
}


def expand_abbreviations_multilingual(text, lang="en"):
    for regex, replacement in _abbreviations[lang]:
        text = re.sub(regex, replacement, text)
    return text


_symbols_multilingual = {
    "en": [
        (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
        for x in [
            ("&", " and "),
            ("@", " at "),
            ("%", " percent "),
            ("#", " hash "),
            ("$", " dollar "),
            ("£", " pound "),
            ("°", " degree "),
        ]
    ],
}


def expand_symbols_multilingual(text, lang="en"):
    for regex, replacement in _symbols_multilingual[lang]:
        text = re.sub(regex, replacement, text)
        text = text.replace("  ", " ")  # Ensure there are no double spaces
    return text.strip()


_ordinal_re = {
    "en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
}
_number_re = re.compile(r"[0-9]+")
_currency_re = {
    "USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
    "GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
    "EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
}

_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
_dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")


def _remove_commas(m):
    text = m.group(0)
    if "," in text:
        text = text.replace(",", "")
    return text


def _remove_dots(m):
    text = m.group(0)
    if "." in text:
        text = text.replace(".", "")
    return text


def _expand_decimal_point(m, lang="en"):
    amount = m.group(1).replace(",", ".")
    return number_to_text(float(amount))


def _expand_currency(m, lang="en", currency="USD"):
    amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
    full_amount = number_to_text(amount)

    and_equivalents = {
        "en": ", ",
        "es": " con ",
        "fr": " et ",
        "de": " und ",
        "pt": " e ",
        "it": " e ",
        "pl": ", ",
        "cs": ", ",
        "ru": ", ",
        "nl": ", ",
        "ar": ", ",
        "tr": ", ",
        "hu": ", ",
        "ko": ", ",
    }

    if amount.is_integer():
        last_and = full_amount.rfind(and_equivalents[lang])
        if last_and != -1:
            full_amount = full_amount[:last_and]

    return full_amount


def _expand_ordinal(m, lang="en"):
    return number_to_text(int(m.group(1)), ordinal=True)


def _expand_number(m, lang="en"):
    return number_to_text(int(m.group(0)))


def expand_numbers_multilingual(text, lang="en"):
    if lang in ["en", "ru"]:
        text = re.sub(_comma_number_re, _remove_commas, text)
    else:
        text = re.sub(_dot_number_re, _remove_dots, text)
    try:
        text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
        text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
        text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
    except:
        pass

    text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
    text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
    text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
    return text


def lowercase(text):
    return text.lower()


def collapse_whitespace(text):
    return re.sub(_whitespace_re, " ", text)


def multilingual_cleaners(text, lang):
    text = text.replace('"', "")
    if lang == "tr":
        text = text.replace("İ", "i")
        text = text.replace("Ö", "ö")
        text = text.replace("Ü", "ü")
    text = lowercase(text)
    try:
        text = expand_numbers_multilingual(text, lang)
    except:
        pass
    try:
        text = expand_abbreviations_multilingual(text, lang)
    except:
        pass
    try:
        text = expand_symbols_multilingual(text, lang=lang)
    except:
        pass
    text = collapse_whitespace(text)
    return text


def basic_cleaners(text):
    """Basic pipeline that lowercases and collapses whitespace without transliteration."""
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text
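A few hand-traced examples of the helpers above (outputs worked out from the code, not taken from the diff):

number_to_text(-42)     # -> "negative forty two"
number_to_text(3.14)    # -> "three point one four"
number_to_text(1234)    # -> "one thousand two hundred thirty four"
expand_abbreviations_multilingual("mr. smith", "en")  # -> "mister smith"
expand_symbols_multilingual("rock & roll", "en")      # -> "rock and roll"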
22 comfy/text_encoders/umt5_config_base.json Normal file
@ -0,0 +1,22 @@
{
    "d_ff": 2048,
    "d_kv": 64,
    "d_model": 768,
    "decoder_start_token_id": 0,
    "dropout_rate": 0.1,
    "eos_token_id": 1,
    "dense_act_fn": "gelu_pytorch_tanh",
    "initializer_factor": 1.0,
    "is_encoder_decoder": true,
    "is_gated_act": true,
    "layer_norm_epsilon": 1e-06,
    "model_type": "umt5",
    "num_decoder_layers": 12,
    "num_heads": 12,
    "num_layers": 12,
    "output_past": true,
    "pad_token_id": 0,
    "relative_attention_num_buckets": 32,
    "tie_word_embeddings": false,
    "vocab_size": 256384
}
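(This config lines up with ace.py above: d_model 768 matches the embedding_size=768 passed to UMT5BaseTokenizer, and eos_token_id/pad_token_id correspond to the special_tokens={"end": 1, "pad": 0} given to UMT5BaseModel.)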
46 comfy_extras/nodes_ace.py Normal file
@ -0,0 +1,46 @@
import torch
import comfy.model_management


class TextEncodeAceStepAudio:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "clip": ("CLIP", ),
            "tags": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            "lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}),
        }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"

    CATEGORY = "conditioning"

    def encode(self, clip, tags, lyrics):
        tokens = clip.tokenize(tags, lyrics=lyrics)
        return (clip.encode_from_tokens_scheduled(tokens), )


class EmptyAceStepLatentAudio:
    def __init__(self):
        self.device = comfy.model_management.intermediate_device()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
                             }}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"

    CATEGORY = "latent/audio"

    def generate(self, seconds, batch_size):
        length = int(seconds * 44100 / 512 / 8)
        latent = torch.zeros([batch_size, 8, 16, length], device=self.device)
        return ({"samples": latent, "type": "audio"}, )


NODE_CLASS_MAPPINGS = {
    "TextEncodeAceStepAudio": TextEncodeAceStepAudio,
    "EmptyAceStepLatentAudio": EmptyAceStepLatentAudio,
}
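For reference, the latent length is straight arithmetic on the formula above: at the default of 120 seconds, int(120 * 44100 / 512 / 8) = 1291, so the node returns a [batch_size, 8, 16, 1291] tensor tagged as type "audio". Each latent frame therefore covers 512 * 8 = 4096 samples, roughly 93 ms at 44.1 kHz.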
6 nodes.py
@ -246,6 +246,9 @@ class ConditioningZeroOut:
            pooled_output = d.get("pooled_output", None)
            if pooled_output is not None:
                d["pooled_output"] = torch.zeros_like(pooled_output)
            conditioning_lyrics = d.get("conditioning_lyrics", None)
            if conditioning_lyrics is not None:
                d["conditioning_lyrics"] = torch.zeros_like(conditioning_lyrics)
            n = [torch.zeros_like(t[0]), d]
            c.append(n)
        return (c, )
@ -917,7 +920,7 @@ class CLIPLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace"], ),
                              },
                "optional": {
                              "device": (["default", "cpu"], {"advanced": True}),
@ -2259,6 +2262,7 @@ def init_builtin_extra_nodes():
        "nodes_hidream.py",
        "nodes_fresca.py",
        "nodes_preview_any.py",
        "nodes_ace.py",
    ]

    import_failed = []
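Pulling it together, a rough sketch of a minimal ACE-Step text-to-audio setup using what this commit adds (direct Python calls for illustration; the method names of the pre-existing nodes are recalled from nodes.py rather than shown in these hunks, and the checkpoint file name is a placeholder):

clip = CLIPLoader().load_clip("umt5_base_ace.safetensors", type="ace")[0]
positive = TextEncodeAceStepAudio().encode(clip, "pop, female vocals, upbeat", "[en]la la la")[0]
negative = ConditioningZeroOut().zero_out(positive)[0]  # now also zeroes conditioning_lyrics
latent = EmptyAceStepLatentAudio().generate(seconds=120.0, batch_size=1)[0]
# positive, negative and latent then feed the usual sampler and audio VAE decode nodes,
# which this commit leaves untouched.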