mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-15 05:57:57 +00:00
Made CosmosVideo work with optimized_attention_override
This commit is contained in:
@@ -176,6 +176,7 @@ class Attention(nn.Module):
|
|||||||
context=None,
|
context=None,
|
||||||
mask=None,
|
mask=None,
|
||||||
rope_emb=None,
|
rope_emb=None,
|
||||||
|
transformer_options={},
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -184,7 +185,7 @@ class Attention(nn.Module):
|
|||||||
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
|
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
|
||||||
"""
|
"""
|
||||||
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
|
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
|
||||||
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
|
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True, transformer_options=transformer_options)
|
||||||
del q, k, v
|
del q, k, v
|
||||||
out = rearrange(out, " b n s c -> s b (n c)")
|
out = rearrange(out, " b n s c -> s b (n c)")
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
@@ -546,6 +547,7 @@ class VideoAttn(nn.Module):
|
|||||||
context: Optional[torch.Tensor] = None,
|
context: Optional[torch.Tensor] = None,
|
||||||
crossattn_mask: Optional[torch.Tensor] = None,
|
crossattn_mask: Optional[torch.Tensor] = None,
|
||||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||||
|
transformer_options: Optional[dict] = {},
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Forward pass for video attention.
|
Forward pass for video attention.
|
||||||
@@ -571,6 +573,7 @@ class VideoAttn(nn.Module):
|
|||||||
context_M_B_D,
|
context_M_B_D,
|
||||||
crossattn_mask,
|
crossattn_mask,
|
||||||
rope_emb=rope_emb_L_1_1_D,
|
rope_emb=rope_emb_L_1_1_D,
|
||||||
|
transformer_options=transformer_options,
|
||||||
)
|
)
|
||||||
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
|
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
|
||||||
return x_T_H_W_B_D
|
return x_T_H_W_B_D
|
||||||
@@ -665,6 +668,7 @@ class DITBuildingBlock(nn.Module):
|
|||||||
crossattn_mask: Optional[torch.Tensor] = None,
|
crossattn_mask: Optional[torch.Tensor] = None,
|
||||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||||
|
transformer_options: Optional[dict] = {},
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Forward pass for dynamically configured blocks with adaptive normalization.
|
Forward pass for dynamically configured blocks with adaptive normalization.
|
||||||
@@ -702,6 +706,7 @@ class DITBuildingBlock(nn.Module):
|
|||||||
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
|
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
|
||||||
context=None,
|
context=None,
|
||||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||||
|
transformer_options=transformer_options,
|
||||||
)
|
)
|
||||||
elif self.block_type in ["cross_attn", "ca"]:
|
elif self.block_type in ["cross_attn", "ca"]:
|
||||||
x = x + gate_1_1_1_B_D * self.block(
|
x = x + gate_1_1_1_B_D * self.block(
|
||||||
@@ -709,6 +714,7 @@ class DITBuildingBlock(nn.Module):
|
|||||||
context=crossattn_emb,
|
context=crossattn_emb,
|
||||||
crossattn_mask=crossattn_mask,
|
crossattn_mask=crossattn_mask,
|
||||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||||
|
transformer_options=transformer_options,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown block type: {self.block_type}")
|
raise ValueError(f"Unknown block type: {self.block_type}")
|
||||||
@@ -784,6 +790,7 @@ class GeneralDITTransformerBlock(nn.Module):
|
|||||||
crossattn_mask: Optional[torch.Tensor] = None,
|
crossattn_mask: Optional[torch.Tensor] = None,
|
||||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||||
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
adaln_lora_B_3D: Optional[torch.Tensor] = None,
|
||||||
|
transformer_options: Optional[dict] = {},
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
x = block(
|
x = block(
|
||||||
@@ -793,5 +800,6 @@ class GeneralDITTransformerBlock(nn.Module):
|
|||||||
crossattn_mask,
|
crossattn_mask,
|
||||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||||
adaln_lora_B_3D=adaln_lora_B_3D,
|
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||||
|
transformer_options=transformer_options,
|
||||||
)
|
)
|
||||||
return x
|
return x
|
||||||
|
@@ -520,6 +520,7 @@ class GeneralDIT(nn.Module):
|
|||||||
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
|
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
|
||||||
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
|
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
|
||||||
|
|
||||||
|
transformer_options = kwargs.get("transformer_options", {})
|
||||||
for _, block in self.blocks.items():
|
for _, block in self.blocks.items():
|
||||||
assert (
|
assert (
|
||||||
self.blocks["block0"].x_format == block.x_format
|
self.blocks["block0"].x_format == block.x_format
|
||||||
@@ -534,6 +535,7 @@ class GeneralDIT(nn.Module):
|
|||||||
crossattn_mask,
|
crossattn_mask,
|
||||||
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
|
||||||
adaln_lora_B_3D=adaln_lora_B_3D,
|
adaln_lora_B_3D=adaln_lora_B_3D,
|
||||||
|
transformer_options=transformer_options,
|
||||||
)
|
)
|
||||||
|
|
||||||
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
|
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
|
||||||
|
Reference in New Issue
Block a user