Fix mask issue in some attention functions.

comfyanonymous
2024-11-22 02:10:09 -05:00
parent 8f0009aad0
commit 2fd9c1308a
2 changed files with 6 additions and 1 deletion


@@ -299,7 +299,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
                     if len(mask.shape) == 2:
                         s1 += mask[i:end]
                     else:
-                        s1 += mask[:, i:end]
+                        if mask.shape[1] == 1:
+                            s1 += mask
+                        else:
+                            s1 += mask[:, i:end]
 
                 s2 = s1.softmax(dim=-1).to(v.dtype)
                 del s1
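
For context, the hunk keeps attention_split from slicing a mask whose query dimension is 1 (a mask meant to broadcast over every query position) when the scores are computed chunk by chunk. Below is a minimal, illustrative sketch of the failure mode; the shapes and variable names are assumptions made for the example, not taken from the repository.

# Illustrative sketch (assumed shapes/names, not from the commit): a (B, 1, Lk)
# attention mask must be broadcast, not sliced along dim 1, when the query
# dimension is processed in chunks the way attention_split does.
import torch

B, Lq, Lk = 1, 8, 8
scores = torch.randn(B, Lq, Lk)              # full attention scores for one head
mask = torch.zeros(B, 1, Lk)                 # mask with a singleton query dim
mask[..., Lk // 2:] = float("-inf")          # mask out the second half of the keys

slice_size = 4
for i in range(0, Lq, slice_size):
    end = i + slice_size
    s1 = scores[:, i:end]                    # chunk of scores: (B, slice, Lk)
    if mask.shape[1] == 1:
        s1 = s1 + mask                       # broadcasts over every query in the chunk
    else:
        s1 = s1 + mask[:, i:end]             # per-query mask: slice to match the chunk
    # With the old code, mask[:, i:end] on a (B, 1, Lk) mask is empty once i >= 1,
    # so the addition raises a shape mismatch after the first chunk.
    s2 = s1.softmax(dim=-1)

Here s1 + mask simply broadcasts the singleton query dimension across the chunk, which is what the added branch in the patch does with the in-place s1 += mask.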