Make sure wrap_attn can't recurse into itself, attempt to import SageAttention and FlashAttention even when not enabled so they can be marked as available or not, and create a registry of the available attention functions

Author: Jedrzej Kosinski
Date:   2025-08-28 18:53:20 -07:00
Parent: 669b9ef8e6
Commit: 51a30c2ad7


@@ -8,7 +8,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn, einsum
 from einops import rearrange, repeat
-from typing import Optional
+from typing import Optional, Any, Callable, Union
 import logging
 import functools
@@ -21,23 +21,58 @@ if model_management.xformers_enabled():
     import xformers
     import xformers.ops
 
-if model_management.sage_attention_enabled():
-    try:
-        from sageattention import sageattn
-    except ModuleNotFoundError as e:
+SAGE_ATTENTION_IS_AVAILABLE = False
+try:
+    from sageattention import sageattn
+    SAGE_ATTENTION_IS_AVAILABLE = True
+except ModuleNotFoundError as e:
+    if model_management.sage_attention_enabled():
         if e.name == "sageattention":
             logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
         else:
             raise e
         exit(-1)
 
-if model_management.flash_attention_enabled():
-    try:
-        from flash_attn import flash_attn_func
-    except ModuleNotFoundError:
+FLASH_ATTENTION_IS_AVAILABLE = False
+try:
+    from flash_attn import flash_attn_func
+    FLASH_ATTENTION_IS_AVAILABLE = True
+except ModuleNotFoundError:
+    if model_management.flash_attention_enabled():
         logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
         exit(-1)
 
+REGISTERED_ATTENTION_FUNCTIONS = {}
+def register_attention_function(name: str, func: Callable):
+    # avoid replacing existing functions
+    if name not in REGISTERED_ATTENTION_FUNCTIONS:
+        REGISTERED_ATTENTION_FUNCTIONS[name] = func
+    else:
+        logging.warning(f"Attention function {name} already registered, skipping registration.")
+
+def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]:
+    if name not in REGISTERED_ATTENTION_FUNCTIONS:
+        if default is ...:
+            raise KeyError(f"Attention function {name} not found.")
+        else:
+            return default
+    return REGISTERED_ATTENTION_FUNCTIONS[name]
+
+def _register_core_attention_functions():
+    """
+    Register attention functions exposed by core ComfyUI.
+    """
+    # NOTE: attention_basic is purposely not registered, as it is not used in code
+    if SAGE_ATTENTION_IS_AVAILABLE:
+        register_attention_function("sage", attention_sage)
+    if FLASH_ATTENTION_IS_AVAILABLE:
+        register_attention_function("flash", attention_flash)
+    if model_management.xformers_enabled():
+        register_attention_function("xformers", attention_xformers)
+    register_attention_function("pytorch", attention_pytorch)
+    register_attention_function("sub_quad", attention_sub_quad)
+    register_attention_function("split", attention_split)
+
 from comfy.cli_args import args
 import comfy.ops
 ops = comfy.ops.disable_weight_init
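The registry above is keyed by plain string names, and the first registration for a name wins; duplicates only log a warning. A minimal sketch of how downstream code could register its own backend, assuming this module's import path (comfy.ldm.modules.attention) and the usual (q, k, v, heads, mask=None, ...) backend signature; the "debug" name and the delegating function are hypothetical, not part of this commit:

# Hypothetical registration of a custom attention backend; not part of this commit.
from comfy.ldm.modules.attention import (
    get_attention_function,
    register_attention_function,
)

def debug_attention(q, k, v, heads, mask=None, **kwargs):
    # Print tensor shapes, then delegate to whichever "pytorch" backend is registered.
    print("attention call:", tuple(q.shape), tuple(k.shape), tuple(v.shape), "heads:", heads)
    return get_attention_function("pytorch")(q, k, v, heads, mask=mask, **kwargs)

register_attention_function("debug", debug_attention)  # skipped with a warning if "debug" already exists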
@@ -137,68 +172,76 @@ def has_transformer_options_passed(frame):
 def wrap_attn(func):
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
-        if LOG_ATTN_CALLS:
-            continue_to_add = True
-            to_add = 1000
-            logged_stack = []
-            logged_stack_to_index = -1
-            frame = inspect.currentframe()
-            try:
-                # skip wrapper, start at actual wrapped function
-                frame = frame.f_back
-                while frame and continue_to_add and to_add > 0:
-                    code = frame.f_code
-                    filename = code.co_filename
-                    function = code.co_name
-                    lineno = frame.f_lineno
-                    if function == "_calc_cond_batch_outer":
-                        break
-                    if 'venv' in filename:
-                        frame = frame.f_back
-                        continue
-                    elif 'ComfyUI' not in filename:
-                        frame = frame.f_back
-                        continue
-                    elif 'execution.py' in filename:
-                        frame = frame.f_back
-                        continue
-                    elif 'patcher_extension.py' in filename:
-                        frame = frame.f_back
-                        continue
-                    to_add -= 1
-                    cls_name = get_class_from_frame(frame)
-                    log_string = f"{filename}:{lineno}"
-                    if cls_name:
-                        log_string += f":{cls_name}.{function}"
-                    else:
-                        log_string += f":{function}"
-                    if has_transformer_options_passed(frame):
-                        log_string += ":✅"
-                        if logged_stack_to_index == -1:
-                            logged_stack_to_index = len(logged_stack)
-                    else:
-                        log_string += ":❌"
-                    logged_stack.append(log_string)
-                    # move up the stack
-                    frame = frame.f_back
-                LOG_CONTENTS["|".join(logged_stack)] = (logged_stack_to_index, logged_stack)
-            finally:
-                # Important: break ref cycles so tensors aren't pinned
-                del frame
-        transformer_options = kwargs.get("transformer_options", None)
-        if transformer_options is not None:
-            if "optimized_attention_override" in transformer_options:
-                return transformer_options["optimized_attention_override"](func, transformer_options, *args, **kwargs)
-        return func(*args, **kwargs)
+        remove_attn_wrapper_key = False
+        try:
+            if LOG_ATTN_CALLS:
+                continue_to_add = True
+                to_add = 1000
+                logged_stack = []
+                logged_stack_to_index = -1
+                frame = inspect.currentframe()
+                try:
+                    # skip wrapper, start at actual wrapped function
+                    frame = frame.f_back
+                    while frame and continue_to_add and to_add > 0:
+                        code = frame.f_code
+                        filename = code.co_filename
+                        function = code.co_name
+                        lineno = frame.f_lineno
+                        if function == "_calc_cond_batch_outer":
+                            break
+                        if 'venv' in filename:
+                            frame = frame.f_back
+                            continue
+                        elif 'ComfyUI' not in filename:
+                            frame = frame.f_back
+                            continue
+                        elif 'execution.py' in filename:
+                            frame = frame.f_back
+                            continue
+                        elif 'patcher_extension.py' in filename:
+                            frame = frame.f_back
+                            continue
+                        to_add -= 1
+                        cls_name = get_class_from_frame(frame)
+                        log_string = f"{filename}:{lineno}"
+                        if cls_name:
+                            log_string += f":{cls_name}.{function}"
+                        else:
+                            log_string += f":{function}"
+                        if has_transformer_options_passed(frame):
+                            log_string += ":✅"
+                            if logged_stack_to_index == -1:
+                                logged_stack_to_index = len(logged_stack)
+                        else:
+                            log_string += ":❌"
+                        logged_stack.append(log_string)
+                        # move up the stack
+                        frame = frame.f_back
+                    LOG_CONTENTS["|".join(logged_stack)] = (logged_stack_to_index, logged_stack)
+                finally:
+                    # Important: break ref cycles so tensors aren't pinned
+                    del frame
+            if "_inside_attn_wrapper" not in kwargs:
+                transformer_options = kwargs.get("transformer_options", None)
+                remove_attn_wrapper_key = True
+                kwargs["_inside_attn_wrapper"] = True
+                if transformer_options is not None:
+                    if "optimized_attention_override" in transformer_options:
+                        return transformer_options["optimized_attention_override"](func, *args, **kwargs)
+            return func(*args, **kwargs)
+        finally:
+            if remove_attn_wrapper_key:
+                del kwargs["_inside_attn_wrapper"]
     return wrapper
 @wrap_attn
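The `_inside_attn_wrapper` marker is what keeps wrap_attn from recursing into itself: an `optimized_attention_override` typically calls back into one of the attention functions, all of which are wrapped, and without the marker the inner call would see the same override in transformer_options and dispatch to it again. A sketch of an override written against the call contract in the hunk above; the function name and the logging behavior are illustrative only:

# Hypothetical override; wrap_attn invokes it as override(func, *args, **kwargs),
# where func is the attention implementation that was about to run and kwargs
# still carries transformer_options plus _inside_attn_wrapper=True.
def logging_attention_override(func, *args, **kwargs):
    q = args[0]  # assumes q is the first positional argument, as in the core attention functions
    print(f"attention via {func.__name__}, query shape {tuple(q.shape)}")
    # Re-entering func (or any other wrapped attention function) with these
    # kwargs is safe: the _inside_attn_wrapper marker makes the inner wrapper
    # skip the override instead of recursing back into it.
    return func(*args, **kwargs)

# Installed by a caller via:
#   transformer_options["optimized_attention_override"] = logging_attention_override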
@@ -707,6 +750,7 @@
 else:
     logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
     optimized_attention = attention_sub_quad
+_register_core_attention_functions()
 
 optimized_attention_masked = optimized_attention
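Since `_register_core_attention_functions()` runs at import time, right after the default optimized_attention backend is selected, code elsewhere can check which backends actually imported by probing the registry instead of repeating the try/except imports. A small sketch, assuming the comfy.ldm.modules.attention import path:

# Hypothetical availability probe; the default argument avoids a KeyError.
from comfy.ldm.modules.attention import get_attention_function

flash_attn_backend = get_attention_function("flash", default=None)
if flash_attn_backend is None:
    # flash-attn was not importable, so the "flash" backend was never registered.
    print("flash attention unavailable; sticking with the default optimized_attention")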