Avoid torch compile graphbreak for older pytorch versions (#9344)

Turns out torch.compile has some gaps in its support for the
context-manager-as-decorator syntax. I've sent patches to fix that in
PyTorch, but the fix won't be available to folks running older versions
of PyTorch, hence this trivial patch.
Xiangxi Guo (Ryan)
2025-08-14 20:41:37 -07:00
committed by GitHub
parent ad19a069f6
commit f0d5d0111f
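
The gap the commit message refers to is using sdpa_kernel (a context manager) as a decorator: under torch.compile on PyTorch < 2.9 that form causes a graph break, while an explicit with block traces cleanly. A minimal sketch of the two forms (my illustration, not part of the commit; it assumes PyTorch >= 2.3, where torch.nn.attention.sdpa_kernel is available, and uses a plain MATH backend list in place of the repo's SDPA_BACKEND_PRIORITY):

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Decorator form: sdpa_kernel(...) is a context manager applied as a
# decorator. Compiling this with torch.compile graph-breaks on PyTorch < 2.9.
@sdpa_kernel(backends=[SDPBackend.MATH])
def sdpa_decorated(q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

# With-block form: the same backend selection written inline, which
# torch.compile can trace without breaking the graph.
def sdpa_with_block(q, k, v):
    with sdpa_kernel([SDPBackend.MATH]):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)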

@@ -41,9 +41,11 @@ try:
         SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
 
-    @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True)
     def scaled_dot_product_attention(q, k, v, *args, **kwargs):
-        return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+        # Use this (rather than the decorator syntax) to eliminate graph
+        # break for pytorch < 2.9
+        with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
+            return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
 except (ModuleNotFoundError, TypeError):
     logging.warning("Could not set sdpa backend priority.")
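
One way to check the effect of such a change yourself (again a hedged sketch, not part of the commit): torch._dynamo.explain reports how many graph breaks a function incurs when traced.

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

def sdpa_with_block(q, k, v):
    with sdpa_kernel([SDPBackend.MATH]):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

# Dummy (batch, heads, seq_len, head_dim) tensors for tracing.
q = torch.randn(1, 4, 8, 16)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)

explanation = torch._dynamo.explain(sdpa_with_block)(q, k, v)
# On PyTorch < 2.9 the decorator form reports a break here; the
# with-block form should compile as a single graph.
print(explanation.graph_break_count)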