mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-12 12:37:01 +00:00
Made WAN attention receive transformer_options; added a test node to WAN for testing the attention override later
This commit is contained in:
@@ -52,7 +52,7 @@ class WanSelfAttention(nn.Module):
|
|||||||
self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||||
self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, freqs):
|
def forward(self, x, freqs, transformer_options={}):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
||||||
@@ -75,6 +75,7 @@ class WanSelfAttention(nn.Module):
|
|||||||
k.view(b, s, n * d),
|
k.view(b, s, n * d),
|
||||||
v,
|
v,
|
||||||
heads=self.num_heads,
|
heads=self.num_heads,
|
||||||
|
transformer_options=transformer_options,
|
||||||
)
|
)
|
||||||
|
|
||||||
x = self.o(x)
|
x = self.o(x)
|
||||||
@@ -83,7 +84,7 @@ class WanSelfAttention(nn.Module):
|
|||||||
|
|
||||||
class WanT2VCrossAttention(WanSelfAttention):
|
class WanT2VCrossAttention(WanSelfAttention):
|
||||||
|
|
||||||
def forward(self, x, context, **kwargs):
|
def forward(self, x, context, transformer_options={}, **kwargs):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
x(Tensor): Shape [B, L1, C]
|
x(Tensor): Shape [B, L1, C]
|
||||||
@@ -95,7 +96,7 @@ class WanT2VCrossAttention(WanSelfAttention):
|
|||||||
v = self.v(context)
|
v = self.v(context)
|
||||||
|
|
||||||
# compute attention
|
# compute attention
|
||||||
x = optimized_attention(q, k, v, heads=self.num_heads)
|
x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
|
||||||
|
|
||||||
x = self.o(x)
|
x = self.o(x)
|
||||||
return x
|
return x
|
||||||
@@ -206,6 +207,7 @@ class WanAttentionBlock(nn.Module):
|
|||||||
freqs,
|
freqs,
|
||||||
context,
|
context,
|
||||||
context_img_len=257,
|
context_img_len=257,
|
||||||
|
transformer_options={},
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -224,12 +226,12 @@ class WanAttentionBlock(nn.Module):
|
|||||||
# self-attention
|
# self-attention
|
||||||
y = self.self_attn(
|
y = self.self_attn(
|
||||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||||
freqs)
|
freqs, transformer_options=transformer_options)
|
||||||
|
|
||||||
x = torch.addcmul(x, y, repeat_e(e[2], x))
|
x = torch.addcmul(x, y, repeat_e(e[2], x))
|
||||||
|
|
||||||
# cross-attention & ffn
|
# cross-attention & ffn
|
||||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
|
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||||
return x
|
return x
|
||||||
@@ -564,7 +566,7 @@ class WanModel(torch.nn.Module):
|
|||||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||||
x = out["img"]
|
x = out["img"]
|
||||||
else:
|
else:
|
||||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||||
|
|
||||||
# head
|
# head
|
||||||
x = self.head(x, e)
|
x = self.head(x, e)
|
||||||
|
@@ -1003,6 +1003,42 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
|
|||||||
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
|
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
|
||||||
return io.NodeOutput(out_latent)
|
return io.NodeOutput(out_latent)
|
||||||
|
|
||||||
|
import comfy.patcher_extension
|
||||||
|
import comfy.ldm.modules.attention
|
||||||
|
class AttentionOverrideTest(io.ComfyNode):
    """Dev-tools node that installs a test attention override on a cloned model.

    The node takes a model, clones it, stores ``attention_override`` under the
    ``"optimized_attention_override"`` key of the clone's transformer options,
    registers a pass-through sampler wrapper, and outputs the clone.
    """

    @classmethod
    def define_schema(cls):
        """Return the node schema: one model input and one model output."""
        return io.Schema(
            node_id="AttentionOverrideTest",
            category="devtools",
            inputs=[
                io.Model.Input("model"),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @staticmethod
    def attention_override(func, transformer_options, *args, **kwargs):
        """Run ``attention_basic`` in place of the attention function being overridden.

        Args:
            func: The attention callable being overridden; not used here.
            transformer_options: Options passed along with the call; not used here.
            *args: Positional arguments forwarded unchanged.
            **kwargs: Keyword arguments forwarded unchanged.

        Returns:
            Whatever the selected attention implementation returns.
        """
        new_attention = comfy.ldm.modules.attention.attention_basic
        # NOTE(review): relies on attention_basic exposing `__wrapped__`
        # (presumably set by a functools.wraps-style decorator) so that the
        # undecorated implementation is called -- confirm against the
        # attention module.
        return new_attention.__wrapped__(*args, **kwargs)

    @staticmethod
    def sampler_sampler_wrapper(executor, *args, **kwargs):
        """Pass-through wrapper: invoke ``executor`` with the given arguments.

        Currently a placeholder that adds no behavior of its own.
        """
        return executor(*args, **kwargs)

    @classmethod
    def execute(cls, model: io.Model.Type) -> io.NodeOutput:
        """Clone ``model``, attach the test override and wrapper, and return it.

        Args:
            model: The incoming model; it is cloned before any change is made.

        Returns:
            A node output wrapping the modified clone.
        """
        model = model.clone()

        transformer_options = model.model_options["transformer_options"]
        transformer_options["optimized_attention_override"] = cls.attention_override
        model.add_wrapper_with_key(
            comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
            "attention_override_test",
            cls.sampler_sampler_wrapper,
        )
        return io.NodeOutput(model)
|
||||||
|
|
||||||
class WanExtension(ComfyExtension):
|
class WanExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
@@ -1020,6 +1056,7 @@ class WanExtension(ComfyExtension):
|
|||||||
WanPhantomSubjectToVideo,
|
WanPhantomSubjectToVideo,
|
||||||
WanSoundImageToVideo,
|
WanSoundImageToVideo,
|
||||||
Wan22ImageToVideoLatent,
|
Wan22ImageToVideoLatent,
|
||||||
|
AttentionOverrideTest,
|
||||||
]
|
]
|
||||||
|
|
||||||
async def comfy_entrypoint() -> WanExtension:
|
async def comfy_entrypoint() -> WanExtension:
|
||||||
|
Reference in New Issue
Block a user