Make sure wrap_attn doesn't make itself recurse infinitely, attempt to load SageAttention and FlashAttention even when not enabled so they can be marked as available or not, and create a registry for available attention functions

Author: Jedrzej Kosinski
Date: 2025-08-28 18:53:20 -07:00
parent 669b9ef8e6
commit 51a30c2ad7


@@ -8,7 +8,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn, einsum
 from einops import rearrange, repeat
-from typing import Optional
+from typing import Optional, Any, Callable, Union
 import logging
 import functools
@@ -21,23 +21,58 @@ if model_management.xformers_enabled():
     import xformers
     import xformers.ops
 
-if model_management.sage_attention_enabled():
+SAGE_ATTENTION_IS_AVAILABLE = False
 try:
     from sageattention import sageattn
+    SAGE_ATTENTION_IS_AVAILABLE = True
 except ModuleNotFoundError as e:
+    if model_management.sage_attention_enabled():
         if e.name == "sageattention":
             logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
         else:
             raise e
         exit(-1)
 
-if model_management.flash_attention_enabled():
+FLASH_ATTENTION_IS_AVAILABLE = False
 try:
     from flash_attn import flash_attn_func
+    FLASH_ATTENTION_IS_AVAILABLE = True
 except ModuleNotFoundError:
+    if model_management.flash_attention_enabled():
         logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
         exit(-1)
 
+REGISTERED_ATTENTION_FUNCTIONS = {}
+def register_attention_function(name: str, func: Callable):
+    # avoid replacing existing functions
+    if name not in REGISTERED_ATTENTION_FUNCTIONS:
+        REGISTERED_ATTENTION_FUNCTIONS[name] = func
+    else:
+        logging.warning(f"Attention function {name} already registered, skipping registration.")
+
+def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]:
+    if name not in REGISTERED_ATTENTION_FUNCTIONS:
+        if default is ...:
+            raise KeyError(f"Attention function {name} not found.")
+        else:
+            return default
+    return REGISTERED_ATTENTION_FUNCTIONS[name]
+
+def _register_core_attention_functions():
+    """
+    Register attention functions exposed by core ComfyUI.
+    """
+    # NOTE: attention_basic is purposely not registered, as it is not used in code
+    if SAGE_ATTENTION_IS_AVAILABLE:
+        register_attention_function("sage", attention_sage)
+    if FLASH_ATTENTION_IS_AVAILABLE:
+        register_attention_function("flash", attention_flash)
+    if model_management.xformers_enabled():
+        register_attention_function("xformers", attention_xformers)
+    register_attention_function("pytorch", attention_pytorch)
+    register_attention_function("sub_quad", attention_sub_quad)
+    register_attention_function("split", attention_split)
+
 from comfy.cli_args import args
 import comfy.ops
 ops = comfy.ops.disable_weight_init
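For illustration only, a minimal sketch (not part of this commit) of how custom code could use the new registry; my_custom_attention is hypothetical and its signature is assumed to match the core attention functions:

# hypothetical example, not in the diff
def my_custom_attention(q, k, v, heads, mask=None, **kwargs):
    # delegate to the implementation that _register_core_attention_functions()
    # registers under "pytorch" once this module finishes importing
    return get_attention_function("pytorch")(q, k, v, heads, mask=mask, **kwargs)

register_attention_function("my_custom", my_custom_attention)  # warns and skips if the name is already taken
assert get_attention_function("my_custom") is my_custom_attention
missing = get_attention_function("unknown", default=None)      # returns the default instead of raising KeyError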
@@ -137,6 +172,8 @@ def has_transformer_options_passed(frame):
 def wrap_attn(func):
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
+        remove_attn_wrapper_key = False
+        try:
             if LOG_ATTN_CALLS:
                 continue_to_add = True
                 to_add = 1000
@@ -194,11 +231,17 @@ def wrap_attn(func):
                 finally:
                     # Important: break ref cycles so tensors aren't pinned
                     del frame
+            if "_inside_attn_wrapper" not in kwargs:
                 transformer_options = kwargs.get("transformer_options", None)
+                remove_attn_wrapper_key = True
+                kwargs["_inside_attn_wrapper"] = True
                 if transformer_options is not None:
                     if "optimized_attention_override" in transformer_options:
-                        return transformer_options["optimized_attention_override"](func, transformer_options, *args, **kwargs)
+                        return transformer_options["optimized_attention_override"](func, *args, **kwargs)
             return func(*args, **kwargs)
+        finally:
+            if remove_attn_wrapper_key:
+                del kwargs["_inside_attn_wrapper"]
     return wrapper
 
 @wrap_attn
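For illustration, a hedged sketch (hypothetical override, not part of this commit) of why the _inside_attn_wrapper guard stops infinite recursion: wrap_attn now calls the override as override(func, *args, **kwargs) after setting kwargs["_inside_attn_wrapper"] = True, so when the override calls back into a wrapped attention function with those same kwargs, the inner wrapper takes the plain path instead of invoking the override again:

# hypothetical override, not in the diff
def my_attention_override(func, q, k, v, *args, **kwargs):
    # prefer a registered implementation if present, else fall back to the original func;
    # kwargs still carries _inside_attn_wrapper=True, so this call cannot re-enter the override
    attn = get_attention_function("sage", default=None) or func
    return attn(q, k, v, *args, **kwargs)

# attached by whatever populates transformer_options, e.g.:
# transformer_options["optimized_attention_override"] = my_attention_override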
@@ -707,6 +750,7 @@ else:
     else:
         logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
         optimized_attention = attention_sub_quad
 
+_register_core_attention_functions()
 optimized_attention_masked = optimized_attention
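For illustration, a short hypothetical lookup (not part of this commit): after the module finishes importing, _register_core_attention_functions() has populated the registry, so a caller can pick an implementation by name with a fallback:

# hypothetical lookup, not in the diff; "pytorch" is always registered,
# while "sage"/"flash"/"xformers" exist only when importable/enabled above
attn = get_attention_function("flash", default=None)
if attn is None:
    attn = get_attention_function("pytorch")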