mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-12 12:37:01 +00:00
SDPA backend priority (#9299)
This commit is contained in:
13
comfy/ops.py
13
comfy/ops.py
@@ -23,9 +23,18 @@ from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy.float
|
||||
import comfy.rmsnorm
|
||||
import contextlib
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
|
||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
|
||||
SDPA_BACKEND_PRIORITY = [
|
||||
SDPBackend.FLASH_ATTENTION,
|
||||
SDPBackend.EFFICIENT_ATTENTION,
|
||||
SDPBackend.MATH,
|
||||
]
|
||||
if torch.cuda.is_available():
|
||||
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
@@ -249,6 +258,10 @@ class disable_weight_init:
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
@staticmethod
|
||||
@sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True)
|
||||
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
|
||||
class manual_cast(disable_weight_init):
|
||||
class Linear(disable_weight_init.Linear):
|
||||
|
Reference in New Issue
Block a user