Mochi VAE encoder.

This commit is contained in:
comfyanonymous
2024-11-01 17:33:09 -04:00
parent cc9cf6d1bd
commit fabf449feb
3 changed files with 283 additions and 28 deletions

View File

@@ -393,6 +393,13 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
return out
if model_management.is_nvidia(): #pytorch 2.3 and up seem to have this issue.
SDP_BATCH_LIMIT = 2**15
else:
#TODO: other GPUs ?
SDP_BATCH_LIMIT = 2**31
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
@@ -404,10 +411,15 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
(q, k, v),
)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
if SDP_BATCH_LIMIT >= q.shape[0]:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
else:
out = torch.empty((q.shape[0], q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
for i in range(0, q.shape[0], SDP_BATCH_LIMIT):
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=mask, dropout_p=0.0, is_causal=False).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out