Make last PR not crash comfy on old pytorch. (#9324)

This commit is contained in:
comfyanonymous 2025-08-13 12:12:41 -07:00 committed by GitHub
parent 3da5a07510
commit 9df8792d4b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 27 additions and 17 deletions

View File

@ -178,7 +178,7 @@ class FourierEmbedder(nn.Module):
class CrossAttentionProcessor:
def __call__(self, attn, q, k, v):
out = ops.scaled_dot_product_attention(q, k, v)
out = comfy.ops.scaled_dot_product_attention(q, k, v)
return out

View File

@ -448,7 +448,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
mask = mask.unsqueeze(1)
if SDP_BATCH_LIMIT >= b:
out = ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@ -461,7 +461,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if mask.shape[0] > 1:
m = mask[i : i + SDP_BATCH_LIMIT]
out[i : i + SDP_BATCH_LIMIT] = ops.scaled_dot_product_attention(
out[i : i + SDP_BATCH_LIMIT] = comfy.ops.scaled_dot_product_attention(
q[i : i + SDP_BATCH_LIMIT],
k[i : i + SDP_BATCH_LIMIT],
v[i : i + SDP_BATCH_LIMIT],

View File

@ -285,7 +285,7 @@ def pytorch_attention(q, k, v):
)
try:
out = ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")

View File

@ -23,18 +23,32 @@ from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
import contextlib
from torch.nn.attention import SDPBackend, sdpa_kernel
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
try:
if torch.cuda.is_available():
from torch.nn.attention import SDPBackend, sdpa_kernel
SDPA_BACKEND_PRIORITY = [
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.MATH,
]
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
@sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True)
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
except (ModuleNotFoundError, TypeError):
logging.warning("Could not set sdpa backend priority.")
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
SDPA_BACKEND_PRIORITY = [
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.MATH,
]
if torch.cuda.is_available():
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
@ -258,10 +272,6 @@ class disable_weight_init:
else:
raise ValueError(f"unsupported dimensions: {dims}")
@staticmethod
@sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True)
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
class manual_cast(disable_weight_init):
class Linear(disable_weight_init.Linear):