Fix last commit not working on older pytorch. (#9346)

This commit is contained in:
comfyanonymous
2025-08-14 20:44:02 -07:00
committed by GitHub
parent f0d5d0111f
commit 4e5c230f6a

View File

@@ -32,7 +32,8 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
try:
if torch.cuda.is_available():
from torch.nn.attention import SDPBackend, sdpa_kernel
import inspect
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
SDPA_BACKEND_PRIORITY = [
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
@@ -42,10 +43,10 @@ try:
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
# Use this (rather than the decorator syntax) to eliminate graph
# break for pytorch < 2.9
with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
else:
logging.warning("Torch version too old to set sdpa backend priority.")
except (ModuleNotFoundError, TypeError):
logging.warning("Could not set sdpa backend priority.")