mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 03:58:22 +00:00
Initial ACE-Step model implementation. (#7972)
This commit is contained in:
381
comfy/ldm/ace/model.py
Normal file
381
comfy/ldm/ace/model.py
Normal file
@@ -0,0 +1,381 @@
|
||||
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/ace_step_transformer.py
|
||||
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, List, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import comfy.model_management
|
||||
|
||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
from .attention import LinearTransformerBlock, t2i_modulate
|
||||
from .lyric_encoder import ConformerEncoder as LyricEncoder
|
||||
|
||||
|
||||
def cross_norm(hidden_states, controlnet_input):
|
||||
# input N x T x c
|
||||
mean_hidden_states, std_hidden_states = hidden_states.mean(dim=(1,2), keepdim=True), hidden_states.std(dim=(1,2), keepdim=True)
|
||||
mean_controlnet_input, std_controlnet_input = controlnet_input.mean(dim=(1,2), keepdim=True), controlnet_input.std(dim=(1,2), keepdim=True)
|
||||
controlnet_input = (controlnet_input - mean_controlnet_input) * (std_hidden_states / (std_controlnet_input + 1e-12)) + mean_hidden_states
|
||||
return controlnet_input
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
|
||||
class Qwen2RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, dtype=None, device=None):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=device).float() / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
|
||||
)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||
|
||||
return (
|
||||
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
||||
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
||||
)
|
||||
|
||||
|
||||
class T2IFinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of Sana.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(2, hidden_size, dtype=dtype, device=device))
|
||||
self.out_channels = out_channels
|
||||
self.patch_size = patch_size
|
||||
|
||||
def unpatchfy(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
width: int,
|
||||
):
|
||||
# 4 unpatchify
|
||||
new_height, new_width = 1, hidden_states.size(1)
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(hidden_states.shape[0], new_height, new_width, self.patch_size[0], self.patch_size[1], self.out_channels)
|
||||
).contiguous()
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(hidden_states.shape[0], self.out_channels, new_height * self.patch_size[0], new_width * self.patch_size[1])
|
||||
).contiguous()
|
||||
if width > new_width:
|
||||
output = torch.nn.functional.pad(output, (0, width - new_width, 0, 0), 'constant', 0)
|
||||
elif width < new_width:
|
||||
output = output[:, :, :, :width]
|
||||
return output
|
||||
|
||||
def forward(self, x, t, output_length):
|
||||
shift, scale = (comfy.model_management.cast_to(self.scale_shift_table[None], device=t.device, dtype=t.dtype) + t[:, None]).chunk(2, dim=1)
|
||||
x = t2i_modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
# unpatchify
|
||||
output = self.unpatchfy(x, output_length)
|
||||
return output
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""2D Image to Patch Embedding"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
height=16,
|
||||
width=4096,
|
||||
patch_size=(16, 1),
|
||||
in_channels=8,
|
||||
embed_dim=1152,
|
||||
bias=True,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
patch_size_h, patch_size_w = patch_size
|
||||
self.early_conv_layers = nn.Sequential(
|
||||
operations.Conv2d(in_channels, in_channels*256, kernel_size=patch_size, stride=patch_size, padding=0, bias=bias, dtype=dtype, device=device),
|
||||
operations.GroupNorm(num_groups=32, num_channels=in_channels*256, eps=1e-6, affine=True, dtype=dtype, device=device),
|
||||
operations.Conv2d(in_channels*256, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias, dtype=dtype, device=device)
|
||||
)
|
||||
self.patch_size = patch_size
|
||||
self.height, self.width = height // patch_size_h, width // patch_size_w
|
||||
self.base_size = self.width
|
||||
|
||||
def forward(self, latent):
|
||||
# early convolutions, N x C x H x W -> N x 256 * sqrt(patch_size) x H/patch_size x W/patch_size
|
||||
latent = self.early_conv_layers(latent)
|
||||
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
return latent
|
||||
|
||||
|
||||
class ACEStepTransformer2DModel(nn.Module):
|
||||
# _supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: Optional[int] = 8,
|
||||
num_layers: int = 28,
|
||||
inner_dim: int = 1536,
|
||||
attention_head_dim: int = 64,
|
||||
num_attention_heads: int = 24,
|
||||
mlp_ratio: float = 4.0,
|
||||
out_channels: int = 8,
|
||||
max_position: int = 32768,
|
||||
rope_theta: float = 1000000.0,
|
||||
speaker_embedding_dim: int = 512,
|
||||
text_embedding_dim: int = 768,
|
||||
ssl_encoder_depths: List[int] = [9, 9],
|
||||
ssl_names: List[str] = ["mert", "m-hubert"],
|
||||
ssl_latent_dims: List[int] = [1024, 768],
|
||||
lyric_encoder_vocab_size: int = 6681,
|
||||
lyric_hidden_size: int = 1024,
|
||||
patch_size: List[int] = [16, 1],
|
||||
max_height: int = 16,
|
||||
max_width: int = 4096,
|
||||
audio_model=None,
|
||||
dtype=None, device=None, operations=None
|
||||
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dtype = dtype
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.inner_dim = inner_dim
|
||||
self.out_channels = out_channels
|
||||
self.max_position = max_position
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.rope_theta = rope_theta
|
||||
|
||||
self.rotary_emb = Qwen2RotaryEmbedding(
|
||||
dim=self.attention_head_dim,
|
||||
max_position_embeddings=self.max_position,
|
||||
base=self.rope_theta,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# 2. Define input layers
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.num_layers = num_layers
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
LinearTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
mlp_ratio=mlp_ratio,
|
||||
add_cross_attention=True,
|
||||
add_cross_attention_dim=self.inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for i in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.t_block = nn.Sequential(nn.SiLU(), operations.Linear(self.inner_dim, 6 * self.inner_dim, bias=True, dtype=dtype, device=device))
|
||||
|
||||
# speaker
|
||||
self.speaker_embedder = operations.Linear(speaker_embedding_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
|
||||
# genre
|
||||
self.genre_embedder = operations.Linear(text_embedding_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
|
||||
# lyric
|
||||
self.lyric_embs = operations.Embedding(lyric_encoder_vocab_size, lyric_hidden_size, dtype=dtype, device=device)
|
||||
self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0, dtype=dtype, device=device, operations=operations)
|
||||
self.lyric_proj = operations.Linear(lyric_hidden_size, self.inner_dim, dtype=dtype, device=device)
|
||||
|
||||
projector_dim = 2 * self.inner_dim
|
||||
|
||||
self.projectors = nn.ModuleList([
|
||||
nn.Sequential(
|
||||
operations.Linear(self.inner_dim, projector_dim, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(projector_dim, projector_dim, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(projector_dim, ssl_dim, dtype=dtype, device=device),
|
||||
) for ssl_dim in ssl_latent_dims
|
||||
])
|
||||
|
||||
self.proj_in = PatchEmbed(
|
||||
height=max_height,
|
||||
width=max_width,
|
||||
patch_size=patch_size,
|
||||
embed_dim=self.inner_dim,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward_lyric_encoder(
|
||||
self,
|
||||
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||
lyric_mask: Optional[torch.LongTensor] = None,
|
||||
out_dtype=None,
|
||||
):
|
||||
# N x T x D
|
||||
lyric_embs = self.lyric_embs(lyric_token_idx, out_dtype=out_dtype)
|
||||
prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
|
||||
prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
|
||||
return prompt_prenet_out
|
||||
|
||||
def encode(
|
||||
self,
|
||||
encoder_text_hidden_states: Optional[torch.Tensor] = None,
|
||||
text_attention_mask: Optional[torch.LongTensor] = None,
|
||||
speaker_embeds: Optional[torch.FloatTensor] = None,
|
||||
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||
lyric_mask: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
|
||||
bs = encoder_text_hidden_states.shape[0]
|
||||
device = encoder_text_hidden_states.device
|
||||
|
||||
# speaker embedding
|
||||
encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)
|
||||
|
||||
# genre embedding
|
||||
encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)
|
||||
|
||||
# lyric
|
||||
encoder_lyric_hidden_states = self.forward_lyric_encoder(
|
||||
lyric_token_idx=lyric_token_idx,
|
||||
lyric_mask=lyric_mask,
|
||||
out_dtype=encoder_text_hidden_states.dtype,
|
||||
)
|
||||
|
||||
encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1)
|
||||
|
||||
encoder_hidden_mask = None
|
||||
if text_attention_mask is not None:
|
||||
speaker_mask = torch.ones(bs, 1, device=device)
|
||||
encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)
|
||||
|
||||
return encoder_hidden_states, encoder_hidden_mask
|
||||
|
||||
def decode(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_hidden_mask: torch.Tensor,
|
||||
timestep: Optional[torch.Tensor],
|
||||
output_length: int = 0,
|
||||
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
|
||||
temb = self.t_block(embedded_timestep)
|
||||
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
# controlnet logic
|
||||
if block_controlnet_hidden_states is not None:
|
||||
control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
|
||||
hidden_states = hidden_states + control_condi * controlnet_scale
|
||||
|
||||
# inner_hidden_states = []
|
||||
|
||||
rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
|
||||
encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_hidden_mask,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
|
||||
temb=temb,
|
||||
)
|
||||
|
||||
output = self.final_layer(hidden_states, embedded_timestep, output_length)
|
||||
return output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timestep,
|
||||
attention_mask=None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
text_attention_mask: Optional[torch.LongTensor] = None,
|
||||
speaker_embeds: Optional[torch.FloatTensor] = None,
|
||||
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||
lyric_mask: Optional[torch.LongTensor] = None,
|
||||
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||
**kwargs
|
||||
):
|
||||
hidden_states = x
|
||||
encoder_text_hidden_states = context
|
||||
encoder_hidden_states, encoder_hidden_mask = self.encode(
|
||||
encoder_text_hidden_states=encoder_text_hidden_states,
|
||||
text_attention_mask=text_attention_mask,
|
||||
speaker_embeds=speaker_embeds,
|
||||
lyric_token_idx=lyric_token_idx,
|
||||
lyric_mask=lyric_mask,
|
||||
)
|
||||
|
||||
output_length = hidden_states.shape[-1]
|
||||
|
||||
output = self.decode(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_mask=encoder_hidden_mask,
|
||||
timestep=timestep,
|
||||
output_length=output_length,
|
||||
block_controlnet_hidden_states=block_controlnet_hidden_states,
|
||||
controlnet_scale=controlnet_scale,
|
||||
)
|
||||
|
||||
return output
|
Reference in New Issue
Block a user