Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-09-11 12:06:23 +00:00)
Made WAN attention receive transformer_options; added a test node to wan to try out attention overrides later
@@ -52,7 +52,7 @@ class WanSelfAttention(nn.Module):
         self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
         self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
 
-    def forward(self, x, freqs):
+    def forward(self, x, freqs, transformer_options={}):
         r"""
         Args:
             x(Tensor): Shape [B, L, num_heads, C / num_heads]
@@ -75,6 +75,7 @@ class WanSelfAttention(nn.Module):
             k.view(b, s, n * d),
             v,
             heads=self.num_heads,
+            transformer_options=transformer_options,
         )
 
         x = self.o(x)
@@ -83,7 +84,7 @@ class WanSelfAttention(nn.Module):
 
 class WanT2VCrossAttention(WanSelfAttention):
 
-    def forward(self, x, context, **kwargs):
+    def forward(self, x, context, transformer_options={}, **kwargs):
         r"""
         Args:
             x(Tensor): Shape [B, L1, C]
@@ -95,7 +96,7 @@ class WanT2VCrossAttention(WanSelfAttention):
         v = self.v(context)
 
         # compute attention
-        x = optimized_attention(q, k, v, heads=self.num_heads)
+        x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
 
         x = self.o(x)
         return x
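The receiving end of this new keyword is ComfyUI's optimized_attention dispatch in comfy.ldm.modules.attention, which this diff does not touch. Below is a minimal sketch of the lookup these hunks imply, assuming the hook is read from the "optimized_attention_override" key (the key itself is confirmed by the test node further down; the helper _attention_impl and its body are purely illustrative):

import torch
import torch.nn.functional as F

def _attention_impl(q, k, v, heads=8, **kwargs):
    # Stand-in for the default kernel; the real ComfyUI implementation differs.
    b, s, inner = q.shape
    q, k, v = (t.view(b, -1, heads, inner // heads).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2).reshape(b, s, inner)

def optimized_attention(q, k, v, heads=8, transformer_options={}, **kwargs):
    # Hypothetical dispatch: if an override is registered, hand it the function
    # it replaces plus the original call arguments, matching the
    # (func, transformer_options, *args, **kwargs) signature expected by
    # AttentionOverrideTest.attention_override below.
    override = transformer_options.get("optimized_attention_override")
    if override is not None:
        return override(_attention_impl, transformer_options, q, k, v, heads=heads, **kwargs)
    return _attention_impl(q, k, v, heads=heads, **kwargs)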
@@ -206,6 +207,7 @@ class WanAttentionBlock(nn.Module):
         freqs,
         context,
         context_img_len=257,
+        transformer_options={},
     ):
         r"""
         Args:
@@ -224,12 +226,12 @@ class WanAttentionBlock(nn.Module):
         # self-attention
         y = self.self_attn(
             torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
-            freqs)
+            freqs, transformer_options=transformer_options)
 
         x = torch.addcmul(x, y, repeat_e(e[2], x))
 
         # cross-attention & ffn
-        x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
+        x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
         y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
         x = torch.addcmul(x, y, repeat_e(e[5], x))
         return x
@@ -564,7 +566,7 @@ class WanModel(torch.nn.Module):
                 out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
                 x = out["img"]
             else:
-                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
 
         # head
         x = self.head(x, e)
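Taken together, the hunks above (all in the WAN model implementation) thread transformer_options down the full call chain: WanModel.forward passes it to each WanAttentionBlock, the block forwards it to its self-attention and cross-attention modules, and those hand it to optimized_attention. The second half of the commit, in the WAN node definitions, adds the test node from the commit message, which populates the new hook: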
@@ -1003,6 +1003,42 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
         out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
         return io.NodeOutput(out_latent)
 
+import comfy.patcher_extension
+import comfy.ldm.modules.attention
+class AttentionOverrideTest(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="AttentionOverrideTest",
+            category="devtools",
+            inputs=[
+                io.Model.Input("model"),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @staticmethod
+    def attention_override(func, transformer_options, *args, **kwargs):
+        new_attention = comfy.ldm.modules.attention.attention_basic
+        return new_attention.__wrapped__(*args, **kwargs)
+
+    @staticmethod
+    def sampler_sampler_wrapper(executor, *args, **kwargs):
+        try:
+            # extra_args = args[2]
+            return executor(*args, **kwargs)
+        finally:
+            pass
+
+    @classmethod
+    def execute(cls, model: io.Model.Type) -> io.NodeOutput:
+        model = model.clone()
+
+        model.model_options["transformer_options"]["optimized_attention_override"] = cls.attention_override
+        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, "attention_override_test", cls.sampler_sampler_wrapper)
+        return io.NodeOutput(model)
 
 class WanExtension(ComfyExtension):
     @override
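Two details of the test node are worth noting. First, attention_override ignores func (the kernel it replaces) and calls attention_basic through __wrapped__, which suggests the module's attention functions are decorated and __wrapped__ reaches the raw, undecorated implementation. Second, sampler_sampler_wrapper is a pass-through placeholder for now; the commented-out extra_args hints at state to be read from the sampler arguments later. An override that observes calls instead of swapping kernels could look like the following sketch (the logging is illustrative, not part of the commit; the fallback mirrors the test node exactly):

import comfy.ldm.modules.attention

def logging_attention_override(func, transformer_options, *args, **kwargs):
    # Log the query shape for every attention call, then run the basic kernel.
    # func is deliberately left unused, mirroring AttentionOverrideTest; the
    # commit does not show whether calling it directly would re-enter the hook.
    q = args[0]
    print(f"attention: q {tuple(q.shape)}, heads={kwargs.get('heads')}")
    return comfy.ldm.modules.attention.attention_basic.__wrapped__(*args, **kwargs)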
@@ -1020,6 +1056,7 @@ class WanExtension(ComfyExtension):
             WanPhantomSubjectToVideo,
             WanSoundImageToVideo,
             Wan22ImageToVideoLatent,
+            AttentionOverrideTest,
         ]
 
 async def comfy_entrypoint() -> WanExtension:
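In a workflow the node sits between a model loader and a sampler (e.g. Load Checkpoint → AttentionOverrideTest → KSampler). Because execute() works on model.clone(), the override and the sampler wrapper attach only to the patched copy, so other branches of the graph keep the stock attention path.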