Made optimized_attention_override work with ACE Step
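This commit threads a transformer_options dict through the ACE Step attention stack (Attention.forward, CustomerAttnProcessor2_0, LinearTransformerBlock, and ACEStepTransformer2DModel) so that a caller-supplied attention override finally reaches the optimized_attention call. As a minimal sketch of the caller side, assuming the override lives under an "optimized_attention_override" key and receives the default attention function as its first argument (both the key name and the signature are assumptions for illustration, not shown in this diff):

def logging_attention_override(attention_func, q, k, v, heads, **kwargs):
    # Wrap the stock attention kernel: log the input shapes, then delegate.
    print(f"attention: q={tuple(q.shape)}, heads={heads}")
    return attention_func(q, k, v, heads, **kwargs)

# Hypothetical registration; the dict key is an assumption, not shown in the diff.
transformer_options = {"optimized_attention_override": logging_attention_override}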
@@ -133,6 +133,7 @@ class Attention(nn.Module):
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        transformer_options={},
         **cross_attention_kwargs,
     ) -> torch.Tensor:
         return self.processor(
@@ -140,6 +141,7 @@ class Attention(nn.Module):
             hidden_states,
             encoder_hidden_states=encoder_hidden_states,
             attention_mask=attention_mask,
+            transformer_options=transformer_options,
             **cross_attention_kwargs,
         )
 
@@ -366,6 +368,7 @@ class CustomerAttnProcessor2_0:
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
         rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
+        transformer_options={},
         *args,
         **kwargs,
     ) -> torch.Tensor:
@@ -433,7 +436,7 @@ class CustomerAttnProcessor2_0:
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         hidden_states = optimized_attention(
-            query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
+            query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options,
         ).to(query.dtype)
 
         # linear proj
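The one-line change above is the heart of the fix: before it, transformer_options never reached optimized_attention, so a registered override was silently ignored for ACE Step. A simplified sketch of the dispatch this enables, not ComfyUI's actual implementation (function names and the lookup key are assumptions):

import torch.nn.functional as F

def sdpa_attention(q, k, v, heads, mask=None):
    # Default kernel. With skip_reshape=True the tensors arrive already
    # shaped (batch, num_heads, seq_len, head_dim), matching the comment
    # in the hunk above, so no head split/merge is needed here.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

def optimized_attention_sketch(q, k, v, heads, mask=None, skip_reshape=True,
                               transformer_options=None):
    # Consult transformer_options for a caller-installed override;
    # the key name is an assumption for illustration.
    override = (transformer_options or {}).get("optimized_attention_override")
    if override is not None:
        # Give the override the default kernel so it can delegate to it.
        return override(sdpa_attention, q, k, v, heads, mask=mask)
    return sdpa_attention(q, k, v, heads, mask=mask)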
@@ -697,6 +700,7 @@ class LinearTransformerBlock(nn.Module):
         rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
         rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
         temb: torch.FloatTensor = None,
+        transformer_options={},
     ):
 
         N = hidden_states.shape[0]
@@ -720,6 +724,7 @@ class LinearTransformerBlock(nn.Module):
                 encoder_attention_mask=encoder_attention_mask,
                 rotary_freqs_cis=rotary_freqs_cis,
                 rotary_freqs_cis_cross=rotary_freqs_cis_cross,
+                transformer_options=transformer_options,
             )
         else:
             attn_output, _ = self.attn(
@@ -729,6 +734,7 @@ class LinearTransformerBlock(nn.Module):
                 encoder_attention_mask=None,
                 rotary_freqs_cis=rotary_freqs_cis,
                 rotary_freqs_cis_cross=None,
+                transformer_options=transformer_options,
             )
 
         if self.use_adaln_single:
@@ -743,6 +749,7 @@ class LinearTransformerBlock(nn.Module):
                 encoder_attention_mask=encoder_attention_mask,
                 rotary_freqs_cis=rotary_freqs_cis,
                 rotary_freqs_cis_cross=rotary_freqs_cis_cross,
+                transformer_options=transformer_options,
            )
            hidden_states = attn_output + hidden_states
 
@@ -314,6 +314,7 @@ class ACEStepTransformer2DModel(nn.Module):
         output_length: int = 0,
         block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
         controlnet_scale: Union[float, torch.Tensor] = 1.0,
+        transformer_options={},
     ):
         embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
         temb = self.t_block(embedded_timestep)
@@ -339,6 +340,7 @@ class ACEStepTransformer2DModel(nn.Module):
                 rotary_freqs_cis=rotary_freqs_cis,
                 rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
                 temb=temb,
+                transformer_options=transformer_options,
             )
 
         output = self.final_layer(hidden_states, embedded_timestep, output_length)
@@ -393,6 +395,7 @@ class ACEStepTransformer2DModel(nn.Module):
 
         output_length = hidden_states.shape[-1]
 
+        transformer_options = kwargs.get("transformer_options", {})
         output = self.decode(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
@@ -402,6 +405,7 @@ class ACEStepTransformer2DModel(nn.Module):
             output_length=output_length,
             block_controlnet_hidden_states=block_controlnet_hidden_states,
             controlnet_scale=controlnet_scale,
+            transformer_options=transformer_options,
         )
 
         return output
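End to end, the options dict now travels from the model's forward kwargs (picked up via kwargs.get("transformer_options", {}) in the final hunks) down through decode and every transformer block to the attention kernel. A hypothetical caller, with model standing in for an ACEStepTransformer2DModel instance:

options = {"optimized_attention_override": logging_attention_override}
output = model(hidden_states, attention_mask=attention_mask,
               transformer_options=options)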