mirror of
https://github.com/comfyanonymous/ComfyUI.git
Make CosmosPredict2 work with optimized_attention_override
@@ -44,7 +44,7 @@ class GPT2FeedForward(nn.Module):
         return x
 
 
-def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
+def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
     """Computes multi-head attention using PyTorch's native implementation.
 
     This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
@@ -71,7 +71,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H
     q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
     k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
     v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
-    return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
+    return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, transformer_options=transformer_options)
 
 
 class Attention(nn.Module):
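
The point of forwarding transformer_options into optimized_attention is that a callable registered in that dict can replace the default attention kernel. Below is a minimal sketch of that dispatch, assuming the override is stored under an "optimized_attention_override" key (the key name is taken from the commit title; ComfyUI's actual lookup logic and override signature may differ):

# Sketch only: how an attention entry point could honor an override carried in
# transformer_options. Inputs are assumed to be [B, H, S, D] (the
# skip_reshape=True convention) and the output is flattened to [B, S, H*D];
# ComfyUI's real optimized_attention handles more cases than this.
from typing import Optional
import torch
import torch.nn.functional as F

def attention_with_override(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    heads: int,
    transformer_options: Optional[dict] = None,
) -> torch.Tensor:
    override = (transformer_options or {}).get("optimized_attention_override")
    if override is not None:
        # Defer entirely to the user-supplied attention function.
        return override(q, k, v, heads, transformer_options=transformer_options)
    out = F.scaled_dot_product_attention(q, k, v)       # [B, H, S, D]
    b, h, s, d = out.shape
    return out.transpose(1, 2).reshape(b, s, h * d)     # [B, S, H*D]
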
@@ -180,8 +180,8 @@ class Attention(nn.Module):
 
         return q, k, v
 
-    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
-        result = self.attn_op(q, k, v)  # [B, S, H, D]
+    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
+        result = self.attn_op(q, k, v, transformer_options=transformer_options)  # [B, S, H, D]
         return self.output_dropout(self.output_proj(result))
 
     def forward(
@@ -189,6 +189,7 @@ class Attention(nn.Module):
         x: torch.Tensor,
         context: Optional[torch.Tensor] = None,
         rope_emb: Optional[torch.Tensor] = None,
+        transformer_options: Optional[dict] = {},
     ) -> torch.Tensor:
         """
         Args:
@@ -196,7 +197,7 @@ class Attention(nn.Module):
             context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
         """
         q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
-        return self.compute_attention(q, k, v)
+        return self.compute_attention(q, k, v, transformer_options=transformer_options)
 
 
 class Timesteps(nn.Module):
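
Seen from the caller's side, the override is just a function dropped into transformer_options. The diff does not show the exact signature ComfyUI passes to it, so the wrapper below assumes (q, k, v, heads, **kwargs) purely for illustration; it logs shapes and then falls back to plain scaled-dot-product attention:

# Hypothetical override: logs tensor shapes, then uses PyTorch SDPA.
# The (q, k, v, heads, **kwargs) signature and the [B, H, S, D] layout are
# assumptions for illustration; match whatever the attention backend actually
# passes through.
import torch
import torch.nn.functional as F

def logging_attention_override(q, k, v, heads, **kwargs):
    print(f"attention: q={tuple(q.shape)} k={tuple(k.shape)} heads={heads}")
    out = F.scaled_dot_product_attention(q, k, v)    # [B, H, S, D]
    b, h, s, d = out.shape
    return out.transpose(1, 2).reshape(b, s, h * d)  # [B, S, H*D]

transformer_options = {"optimized_attention_override": logging_attention_override}
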
@@ -459,6 +460,7 @@ class Block(nn.Module):
         rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
         adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
         extra_per_block_pos_emb: Optional[torch.Tensor] = None,
+        transformer_options: Optional[dict] = {},
     ) -> torch.Tensor:
         if extra_per_block_pos_emb is not None:
             x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
@@ -512,6 +514,7 @@ class Block(nn.Module):
                 rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
                 None,
                 rope_emb=rope_emb_L_1_1_D,
+                transformer_options=transformer_options,
             ),
             "b (t h w) d -> b t h w d",
             t=T,
@@ -525,6 +528,7 @@ class Block(nn.Module):
             layer_norm_cross_attn: Callable,
             _scale_cross_attn_B_T_1_1_D: torch.Tensor,
             _shift_cross_attn_B_T_1_1_D: torch.Tensor,
+            transformer_options: Optional[dict] = {},
         ) -> torch.Tensor:
             _normalized_x_B_T_H_W_D = _fn(
                 _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
@@ -534,6 +538,7 @@ class Block(nn.Module):
                     rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
                     crossattn_emb,
                     rope_emb=rope_emb_L_1_1_D,
+                    transformer_options=transformer_options,
                 ),
                 "b (t h w) d -> b t h w d",
                 t=T,
@@ -547,6 +552,7 @@ class Block(nn.Module):
             self.layer_norm_cross_attn,
             scale_cross_attn_B_T_1_1_D,
             shift_cross_attn_B_T_1_1_D,
+            transformer_options=transformer_options,
         )
         x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
 
@@ -865,6 +871,7 @@ class MiniTrainDIT(nn.Module):
             "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
             "adaln_lora_B_T_3D": adaln_lora_B_T_3D,
             "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
+            "transformer_options": kwargs.get("transformer_options", {}),
         }
         for block in self.blocks:
             x_B_T_H_W_D = block(
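
At the model level, MiniTrainDIT reads transformer_options out of **kwargs with a default of {}, so call sites that never pass it keep working unchanged. A toy illustration of that pattern (stand-in code, not the real forward signature):

# Toy illustration of the kwargs.get(...) default used above; not MiniTrainDIT.forward.
def forward(x, **kwargs):
    transformer_options = kwargs.get("transformer_options", {})
    override = transformer_options.get("optimized_attention_override")
    return "override set" if override is not None else "default attention"

print(forward("latents"))  # default attention
print(forward("latents", transformer_options={"optimized_attention_override": lambda *a, **k: None}))  # override set
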