Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-09-11 12:06:23 +00:00)
Fix last commit not working on older pytorch. (#9346)
@@ -32,7 +32,8 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
try:
    if torch.cuda.is_available():
        from torch.nn.attention import SDPBackend, sdpa_kernel

        import inspect
        if "set_priority" in inspect.signature(sdpa_kernel).parameters:
            SDPA_BACKEND_PRIORITY = [
                SDPBackend.FLASH_ATTENTION,
                SDPBackend.EFFICIENT_ATTENTION,
@@ -42,10 +43,10 @@ try:
            SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)

            def scaled_dot_product_attention(q, k, v, *args, **kwargs):
                # Use this (rather than the decorator syntax) to eliminate graph
                # break for pytorch < 2.9
                with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
                    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
        else:
            logging.warning("Torch version too old to set sdpa backend priority.")
except (ModuleNotFoundError, TypeError):
    logging.warning("Could not set sdpa backend priority.")