Experimental lyrics strength for ACE. (#7984)

Author: comfyanonymous
Date: 2025-05-07 16:22:07 -07:00
Committed by: GitHub
Parent: b9980592c4
Commit: cc33cd3422
3 changed files with 12 additions and 4 deletions

@@ -273,6 +273,7 @@ class ACEStepTransformer2DModel(nn.Module):
         speaker_embeds: Optional[torch.FloatTensor] = None,
         lyric_token_idx: Optional[torch.LongTensor] = None,
         lyric_mask: Optional[torch.LongTensor] = None,
+        lyrics_strength=1.0,
     ):
         bs = encoder_text_hidden_states.shape[0]
@@ -291,6 +292,8 @@ class ACEStepTransformer2DModel(nn.Module):
out_dtype=encoder_text_hidden_states.dtype,
)
encoder_lyric_hidden_states *= lyrics_strength
encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1)
encoder_hidden_mask = None
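
The scaling is applied just before the speaker, text, and lyric embeddings are concatenated into a single conditioning sequence, so lyrics_strength only rescales the lyric portion of the context. Below is a minimal standalone sketch of that effect, using toy tensor shapes chosen purely for illustration (the real sizes come from the ACE encoders):

```python
import torch

# Toy shapes for illustration only; the real sizes come from the ACE speaker/text/lyric encoders.
bs, n_spk, n_txt, n_lyr, dim = 2, 1, 8, 16, 64
encoder_spk_hidden_states = torch.randn(bs, n_spk, dim)
encoder_text_hidden_states = torch.randn(bs, n_txt, dim)
encoder_lyric_hidden_states = torch.randn(bs, n_lyr, dim)

lyrics_strength = 0.5  # 1.0 reproduces the previous behaviour, 0.0 mutes the lyric conditioning
encoder_lyric_hidden_states = encoder_lyric_hidden_states * lyrics_strength

# Only the lyric slice of the concatenated conditioning is affected.
encoder_hidden_states = torch.cat(
    [encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states],
    dim=1,
)
print(encoder_hidden_states.shape)  # torch.Size([2, 25, 64])
```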
@@ -310,7 +313,6 @@ class ACEStepTransformer2DModel(nn.Module):
         output_length: int = 0,
         block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
         controlnet_scale: Union[float, torch.Tensor] = 1.0,
-        return_dict: bool = True,
     ):
         embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
         temb = self.t_block(embedded_timestep)
@@ -353,6 +355,7 @@ class ACEStepTransformer2DModel(nn.Module):
         lyric_mask: Optional[torch.LongTensor] = None,
         block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
         controlnet_scale: Union[float, torch.Tensor] = 1.0,
+        lyrics_strength=1.0,
         **kwargs
     ):
         hidden_states = x
@@ -363,6 +366,7 @@ class ACEStepTransformer2DModel(nn.Module):
             speaker_embeds=speaker_embeds,
             lyric_token_idx=lyric_token_idx,
             lyric_mask=lyric_mask,
+            lyrics_strength=lyrics_strength,
         )
         output_length = hidden_states.shape[-1]
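
End to end, the change is plain keyword plumbing: forward gains a lyrics_strength argument defaulting to 1.0 and passes it straight through to encode, so existing callers keep the old behaviour. A minimal standalone sketch of that pattern, using toy classes rather than the real ACEStepTransformer2DModel:

```python
import torch
import torch.nn as nn

class ToyTransformer(nn.Module):
    """Illustrative stand-in for the real model; only the lyric path is modelled."""

    def encode(self, lyric_embeds, lyrics_strength=1.0):
        # Scale the lyric conditioning before it would be mixed with the other embeddings.
        return lyric_embeds * lyrics_strength

    def forward(self, lyric_embeds, lyrics_strength=1.0, **kwargs):
        # The new keyword is threaded straight through to encode().
        return self.encode(lyric_embeds, lyrics_strength=lyrics_strength)

model = ToyTransformer()
lyrics = torch.randn(1, 16, 64)
assert torch.equal(model(lyrics), lyrics)                                          # default 1.0 == old behaviour
assert torch.equal(model(lyrics, lyrics_strength=0.0), torch.zeros_like(lyrics))   # 0.0 removes the lyric signal
```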