mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-12 04:27:21 +00:00
Fix WanI2VCrossAttention so that it expects to receive transformer_options
This commit is contained in:
@@ -117,7 +117,7 @@ class WanI2VCrossAttention(WanSelfAttention):
|
||||
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
||||
self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, x, context, context_img_len):
|
||||
def forward(self, x, context, context_img_len, transformer_options={}):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L1, C]
|
||||
@@ -132,9 +132,9 @@ class WanI2VCrossAttention(WanSelfAttention):
|
||||
v = self.v(context)
|
||||
k_img = self.norm_k_img(self.k_img(context_img))
|
||||
v_img = self.v_img(context_img)
|
||||
img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads)
|
||||
img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options)
|
||||
# compute attention
|
||||
x = optimized_attention(q, k, v, heads=self.num_heads)
|
||||
x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
|
||||
|
||||
# output
|
||||
x = x + img_x
|
||||
|
Reference in New Issue
Block a user