From 8fe2dea29729a52d0d7c86342a0b25e216efea5c Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Aug 2025 21:23:03 -0700 Subject: [PATCH] Made CosmosVideo work with optimized_attention_override --- comfy/ldm/cosmos/blocks.py | 10 +++++++++- comfy/ldm/cosmos/model.py | 2 ++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py index 5c4356a3f..afb43d469 100644 --- a/comfy/ldm/cosmos/blocks.py +++ b/comfy/ldm/cosmos/blocks.py @@ -176,6 +176,7 @@ class Attention(nn.Module): context=None, mask=None, rope_emb=None, + transformer_options={}, **kwargs, ): """ @@ -184,7 +185,7 @@ class Attention(nn.Module): context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None """ q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs) - out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True) + out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True, transformer_options=transformer_options) del q, k, v out = rearrange(out, " b n s c -> s b (n c)") return self.to_out(out) @@ -546,6 +547,7 @@ class VideoAttn(nn.Module): context: Optional[torch.Tensor] = None, crossattn_mask: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: """ Forward pass for video attention. @@ -571,6 +573,7 @@ class VideoAttn(nn.Module): context_M_B_D, crossattn_mask, rope_emb=rope_emb_L_1_1_D, + transformer_options=transformer_options, ) x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W) return x_T_H_W_B_D @@ -665,6 +668,7 @@ class DITBuildingBlock(nn.Module): crossattn_mask: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_3D: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: """ Forward pass for dynamically configured blocks with adaptive normalization. @@ -702,6 +706,7 @@ class DITBuildingBlock(nn.Module): adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D), context=None, rope_emb_L_1_1_D=rope_emb_L_1_1_D, + transformer_options=transformer_options, ) elif self.block_type in ["cross_attn", "ca"]: x = x + gate_1_1_1_B_D * self.block( @@ -709,6 +714,7 @@ class DITBuildingBlock(nn.Module): context=crossattn_emb, crossattn_mask=crossattn_mask, rope_emb_L_1_1_D=rope_emb_L_1_1_D, + transformer_options=transformer_options, ) else: raise ValueError(f"Unknown block type: {self.block_type}") @@ -784,6 +790,7 @@ class GeneralDITTransformerBlock(nn.Module): crossattn_mask: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_3D: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: for block in self.blocks: x = block( @@ -793,5 +800,6 @@ class GeneralDITTransformerBlock(nn.Module): crossattn_mask, rope_emb_L_1_1_D=rope_emb_L_1_1_D, adaln_lora_B_3D=adaln_lora_B_3D, + transformer_options=transformer_options, ) return x diff --git a/comfy/ldm/cosmos/model.py b/comfy/ldm/cosmos/model.py index 53698b758..52ef7ef43 100644 --- a/comfy/ldm/cosmos/model.py +++ b/comfy/ldm/cosmos/model.py @@ -520,6 +520,7 @@ class GeneralDIT(nn.Module): x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}" + transformer_options = kwargs.get("transformer_options", {}) for _, block in self.blocks.items(): assert ( self.blocks["block0"].x_format == block.x_format @@ -534,6 +535,7 @@ class GeneralDIT(nn.Module): crossattn_mask, rope_emb_L_1_1_D=rope_emb_L_1_1_D, adaln_lora_B_3D=adaln_lora_B_3D, + transformer_options=transformer_options, ) x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")