From dd21b4aa51a346396f442970728b4c6067900b03 Mon Sep 17 00:00:00 2001
From: Jedrzej Kosinski
Date: Wed, 27 Aug 2025 17:56:21 -0700
Subject: [PATCH] Made WAN attention receive transformer_options; added test node to wan to test out attention override later

---
 comfy/ldm/wan/model.py    | 14 ++++++------
 comfy_extras/nodes_wan.py | 37 +++++++++++++++++++++++++++++++++++++
 2 files changed, 45 insertions(+), 6 deletions(-)

diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index dedfb47e2..7627da643 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -52,7 +52,7 @@ class WanSelfAttention(nn.Module):
         self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
         self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
 
-    def forward(self, x, freqs):
+    def forward(self, x, freqs, transformer_options={}):
         r"""
         Args:
             x(Tensor): Shape [B, L, num_heads, C / num_heads]
@@ -75,6 +75,7 @@ class WanSelfAttention(nn.Module):
             k.view(b, s, n * d),
             v,
             heads=self.num_heads,
+            transformer_options=transformer_options,
         )
 
         x = self.o(x)
@@ -83,7 +84,7 @@ class WanSelfAttention(nn.Module):
 
 class WanT2VCrossAttention(WanSelfAttention):
 
-    def forward(self, x, context, **kwargs):
+    def forward(self, x, context, transformer_options={}, **kwargs):
         r"""
         Args:
             x(Tensor): Shape [B, L1, C]
@@ -95,7 +96,7 @@ class WanT2VCrossAttention(WanSelfAttention):
         v = self.v(context)
 
         # compute attention
-        x = optimized_attention(q, k, v, heads=self.num_heads)
+        x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
 
         x = self.o(x)
         return x
@@ -206,6 +207,7 @@ class WanAttentionBlock(nn.Module):
         freqs,
         context,
         context_img_len=257,
+        transformer_options={},
     ):
         r"""
         Args:
@@ -224,12 +226,12 @@ class WanAttentionBlock(nn.Module):
         # self-attention
         y = self.self_attn(
             torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
-            freqs)
+            freqs, transformer_options=transformer_options)
 
         x = torch.addcmul(x, y, repeat_e(e[2], x))
 
         # cross-attention & ffn
-        x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
+        x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
         y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
         x = torch.addcmul(x, y, repeat_e(e[5], x))
         return x
@@ -564,7 +566,7 @@ class WanModel(torch.nn.Module):
                 out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
                 x = out["img"]
             else:
-                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
 
         # head
         x = self.head(x, e)
diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py
index 312260f00..4fedbba7d 100644
--- a/comfy_extras/nodes_wan.py
+++ b/comfy_extras/nodes_wan.py
@@ -1003,6 +1003,42 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
         out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
         return io.NodeOutput(out_latent)
 
+import comfy.patcher_extension
+import comfy.ldm.modules.attention
+class AttentionOverrideTest(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="AttentionOverrideTest",
+            category="devtools",
+            inputs=[
+                io.Model.Input("model"),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @staticmethod
+    def attention_override(func, transformer_options, *args, **kwargs):
+        new_attention = comfy.ldm.modules.attention.attention_basic
+        return new_attention.__wrapped__(*args, **kwargs)
+
+    @staticmethod
+    def sampler_sampler_wrapper(executor, *args, **kwargs):
+        try:
+            # extra_args = args[2]
+            return executor(*args, **kwargs)
+        finally:
+            pass
+
+    @classmethod
+    def execute(cls, model: io.Model.Type) -> io.NodeOutput:
+        model = model.clone()
+
+        model.model_options["transformer_options"]["optimized_attention_override"] = cls.attention_override
+        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, "attention_override_test", cls.sampler_sampler_wrapper)
+        return io.NodeOutput(model)
 
 class WanExtension(ComfyExtension):
     @override
@@ -1020,6 +1056,7 @@ class WanExtension(ComfyExtension):
             WanPhantomSubjectToVideo,
             WanSoundImageToVideo,
             Wan22ImageToVideoLatent,
+            AttentionOverrideTest,
         ]
 
 async def comfy_entrypoint() -> WanExtension:
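
Usage sketch (not part of the patch): the AttentionOverrideTest node wires the new hook in by cloning the ModelPatcher and storing a callable in transformer_options under "optimized_attention_override", which the WAN attention layers now forward to optimized_attention. The helper below does the same thing outside a node, mirroring the test node's code; the helper names are hypothetical, and the (func, transformer_options, *args, **kwargs) calling convention is assumed from the test node's own signature rather than stated elsewhere in this patch.

import comfy.ldm.modules.attention


def basic_attention_override(func, transformer_options, *args, **kwargs):
    # Ignore the attention implementation the model would normally run ('func')
    # and call attention_basic instead; __wrapped__ bypasses its wrapper so the
    # override is not re-entered, exactly as AttentionOverrideTest.attention_override does.
    return comfy.ldm.modules.attention.attention_basic.__wrapped__(*args, **kwargs)


def apply_basic_attention(model):
    # Clone the ModelPatcher so the caller's model_options stay untouched, then
    # register the override under the key the WAN blocks now pass through.
    model = model.clone()
    model.model_options["transformer_options"]["optimized_attention_override"] = basic_attention_override
    return model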