diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 14d41a8c..a18a6929 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -239,6 +239,12 @@ def attention_split(q, k, v, heads, mask=None):
         else:
             s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
 
+        if mask is not None:
+            if len(mask.shape) == 2:
+                s1 += mask[i:end]
+            else:
+                s1 += mask[:, i:end]
+
         s2 = s1.softmax(dim=-1).to(v.dtype)
         del s1
         first_op_done = True
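
For readers less familiar with the sliced-attention path, here is a minimal, self-contained sketch of what the added mask handling does. It is an illustration, not part of the patch: the tensor sizes and names below (`b_h`, `n_q`, `n_k`, `slice_size`, `mask_2d`, `mask_3d`) are hypothetical, and only the indexing logic mirrors the diff. A 2D mask of shape `(n_q, n_k)` is shared across the batch*heads dimension and is sliced on its query axis (dim 0), while a 3D mask of shape `(b_h, n_q, n_k)` keeps its batch axis and is sliced on dim 1.

```python
import torch
from torch import einsum

# Hypothetical sizes, for illustration only.
b_h, n_q, n_k, d = 2, 8, 8, 4        # batch*heads, query len, key len, head dim
q = torch.randn(b_h, n_q, d)
k = torch.randn(b_h, n_k, d)
scale = d ** -0.5

mask_2d = torch.zeros(n_q, n_k)       # one mask shared by every batch*heads row
mask_3d = torch.zeros(b_h, n_q, n_k)  # a separate mask per batch*heads row

slice_size = 4
for i in range(0, n_q, slice_size):
    end = i + slice_size
    # Attention scores for this query slice, as in attention_split.
    s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
    # 2D mask: slice the query axis (dim 0); broadcasting adds it to every row.
    s1 = s1 + mask_2d[i:end]
    # 3D mask: keep the batch axis, slice the query axis (dim 1).
    s1 = s1 + mask_3d[:, i:end]
    s2 = s1.softmax(dim=-1)
```

The in-place `s1 += ...` in the patch behaves the same way; the only requirement is that the mask's key dimension matches `k.shape[1]` so the broadcasted addition lines up with the per-slice attention scores.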