mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-12 12:37:01 +00:00
Fix last commit not working on older pytorch. (#9346)
This commit is contained in:
25
comfy/ops.py
25
comfy/ops.py
@@ -32,20 +32,21 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
|||||||
try:
|
try:
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||||
|
import inspect
|
||||||
|
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
|
||||||
|
SDPA_BACKEND_PRIORITY = [
|
||||||
|
SDPBackend.FLASH_ATTENTION,
|
||||||
|
SDPBackend.EFFICIENT_ATTENTION,
|
||||||
|
SDPBackend.MATH,
|
||||||
|
]
|
||||||
|
|
||||||
SDPA_BACKEND_PRIORITY = [
|
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
|
||||||
SDPBackend.FLASH_ATTENTION,
|
|
||||||
SDPBackend.EFFICIENT_ATTENTION,
|
|
||||||
SDPBackend.MATH,
|
|
||||||
]
|
|
||||||
|
|
||||||
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
|
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||||
|
with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
|
||||||
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||||
# Use this (rather than the decorator syntax) to eliminate graph
|
else:
|
||||||
# break for pytorch < 2.9
|
logging.warning("Torch version too old to set sdpa backend priority.")
|
||||||
with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
|
|
||||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
|
||||||
except (ModuleNotFoundError, TypeError):
|
except (ModuleNotFoundError, TypeError):
|
||||||
logging.warning("Could not set sdpa backend priority.")
|
logging.warning("Could not set sdpa backend priority.")
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user