From 3da5a07510794c37d437cbea1d94065bb0aa8ebc Mon Sep 17 00:00:00 2001 From: contentis Date: Wed, 13 Aug 2025 20:53:27 +0200 Subject: [PATCH] SDPA backend priority (#9299) --- comfy/ldm/hunyuan3d/vae.py | 2 +- comfy/ldm/modules/attention.py | 4 ++-- comfy/ldm/modules/diffusionmodules/model.py | 2 +- comfy/ops.py | 13 +++++++++++++ 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py index 5eb2c6548..bea6090a2 100644 --- a/comfy/ldm/hunyuan3d/vae.py +++ b/comfy/ldm/hunyuan3d/vae.py @@ -178,7 +178,7 @@ class FourierEmbedder(nn.Module): class CrossAttentionProcessor: def __call__(self, attn, q, k, v): - out = F.scaled_dot_product_attention(q, k, v) + out = ops.scaled_dot_product_attention(q, k, v) return out diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 35d2270ee..19c3c7af1 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -448,7 +448,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha mask = mask.unsqueeze(1) if SDP_BATCH_LIMIT >= b: - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + out = ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) if not skip_output_reshape: out = ( out.transpose(1, 2).reshape(b, -1, heads * dim_head) @@ -461,7 +461,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha if mask.shape[0] > 1: m = mask[i : i + SDP_BATCH_LIMIT] - out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention( + out[i : i + SDP_BATCH_LIMIT] = ops.scaled_dot_product_attention( q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 5c0373b74..79160412f 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -285,7 +285,7 @@ def pytorch_attention(q, k, v): ) try: - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + out = ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = out.transpose(2, 3).reshape(orig_shape) except model_management.OOM_EXCEPTION: logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") diff --git a/comfy/ops.py b/comfy/ops.py index 2cc9bbc27..8b7b662b6 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -23,9 +23,18 @@ from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm import contextlib +from torch.nn.attention import SDPBackend, sdpa_kernel cast_to = comfy.model_management.cast_to #TODO: remove once no more references +SDPA_BACKEND_PRIORITY = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, +] +if torch.cuda.is_available(): + SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) + def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) @@ -249,6 +258,10 @@ class disable_weight_init: else: raise ValueError(f"unsupported dimensions: {dims}") + @staticmethod + @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True) + def scaled_dot_product_attention(q, k, v, *args, **kwargs): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) class manual_cast(disable_weight_init): class Linear(disable_weight_init.Linear):