Made WAN attention receive transformer_options, test node added to wan to test out attention override later

This commit is contained in:
Jedrzej Kosinski
2025-08-27 17:56:21 -07:00
parent 29b7990dc2
commit dd21b4aa51
2 changed files with 45 additions and 6 deletions

View File

@@ -52,7 +52,7 @@ class WanSelfAttention(nn.Module):
self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, freqs):
def forward(self, x, freqs, transformer_options={}):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
@@ -75,6 +75,7 @@ class WanSelfAttention(nn.Module):
k.view(b, s, n * d),
v,
heads=self.num_heads,
transformer_options=transformer_options,
)
x = self.o(x)
@@ -83,7 +84,7 @@ class WanSelfAttention(nn.Module):
class WanT2VCrossAttention(WanSelfAttention):
def forward(self, x, context, **kwargs):
def forward(self, x, context, transformer_options={}, **kwargs):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@@ -95,7 +96,7 @@ class WanT2VCrossAttention(WanSelfAttention):
v = self.v(context)
# compute attention
x = optimized_attention(q, k, v, heads=self.num_heads)
x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
x = self.o(x)
return x
@@ -206,6 +207,7 @@ class WanAttentionBlock(nn.Module):
freqs,
context,
context_img_len=257,
transformer_options={},
):
r"""
Args:
@@ -224,12 +226,12 @@ class WanAttentionBlock(nn.Module):
# self-attention
y = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs)
freqs, transformer_options=transformer_options)
x = torch.addcmul(x, y, repeat_e(e[2], x))
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = torch.addcmul(x, y, repeat_e(e[5], x))
return x
@@ -564,7 +566,7 @@ class WanModel(torch.nn.Module):
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
# head
x = self.head(x, e)

View File

@@ -1003,6 +1003,42 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
return io.NodeOutput(out_latent)
import comfy.patcher_extension
import comfy.ldm.modules.attention
class AttentionOverrideTest(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="AttentionOverrideTest",
category="devtools",
inputs=[
io.Model.Input("model"),
],
outputs=[
io.Model.Output(),
],
)
@staticmethod
def attention_override(func, transformer_options, *args, **kwargs):
new_attention = comfy.ldm.modules.attention.attention_basic
return new_attention.__wrapped__(*args, **kwargs)
@staticmethod
def sampler_sampler_wrapper(executor, *args, **kwargs):
try:
# extra_args = args[2]
return executor(*args, **kwargs)
finally:
pass
@classmethod
def execute(cls, model: io.Model.Type) -> io.NodeOutput:
model = model.clone()
model.model_options["transformer_options"]["optimized_attention_override"] = cls.attention_override
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, "attention_override_test", cls.sampler_sampler_wrapper)
return io.NodeOutput(model)
class WanExtension(ComfyExtension):
@override
@@ -1020,6 +1056,7 @@ class WanExtension(ComfyExtension):
WanPhantomSubjectToVideo,
WanSoundImageToVideo,
Wan22ImageToVideoLatent,
AttentionOverrideTest,
]
async def comfy_entrypoint() -> WanExtension: