Make sure wrap_attn doesn't recurse into itself infinitely, attempt to import SageAttention and FlashAttention even when not enabled so they can be marked as available or unavailable, and create a registry of available attention functions

Author: Jedrzej Kosinski
Date: 2025-08-28 18:53:20 -07:00
parent 669b9ef8e6
commit 51a30c2ad7


@@ -8,7 +8,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn, einsum
 from einops import rearrange, repeat
-from typing import Optional
+from typing import Optional, Any, Callable, Union
 import logging
 import functools
@@ -21,23 +21,58 @@ if model_management.xformers_enabled():
     import xformers
     import xformers.ops
-if model_management.sage_attention_enabled():
-    try:
-        from sageattention import sageattn
-    except ModuleNotFoundError as e:
+SAGE_ATTENTION_IS_AVAILABLE = False
+try:
+    from sageattention import sageattn
+    SAGE_ATTENTION_IS_AVAILABLE = True
+except ModuleNotFoundError as e:
+    if model_management.sage_attention_enabled():
         if e.name == "sageattention":
             logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
         else:
             raise e
         exit(-1)
-if model_management.flash_attention_enabled():
-    try:
-        from flash_attn import flash_attn_func
-    except ModuleNotFoundError:
+FLASH_ATTENTION_IS_AVAILABLE = False
+try:
+    from flash_attn import flash_attn_func
+    FLASH_ATTENTION_IS_AVAILABLE = True
+except ModuleNotFoundError:
+    if model_management.flash_attention_enabled():
         logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
         exit(-1)
+
+REGISTERED_ATTENTION_FUNCTIONS = {}
+def register_attention_function(name: str, func: Callable):
+    # avoid replacing existing functions
+    if name not in REGISTERED_ATTENTION_FUNCTIONS:
+        REGISTERED_ATTENTION_FUNCTIONS[name] = func
+    else:
+        logging.warning(f"Attention function {name} already registered, skipping registration.")
+
+def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]:
+    if name not in REGISTERED_ATTENTION_FUNCTIONS:
+        if default is ...:
+            raise KeyError(f"Attention function {name} not found.")
+        else:
+            return default
+    return REGISTERED_ATTENTION_FUNCTIONS[name]
+
+def _register_core_attention_functions():
+    """
+    Register attention functions exposed by core ComfyUI.
+    """
+    # NOTE: attention_basic is purposely not registered, as it is not used in code
+    if SAGE_ATTENTION_IS_AVAILABLE:
+        register_attention_function("sage", attention_sage)
+    if FLASH_ATTENTION_IS_AVAILABLE:
+        register_attention_function("flash", attention_flash)
+    if model_management.xformers_enabled():
+        register_attention_function("xformers", attention_xformers)
+    register_attention_function("pytorch", attention_pytorch)
+    register_attention_function("sub_quad", attention_sub_quad)
+    register_attention_function("split", attention_split)
+
 from comfy.cli_args import args
 import comfy.ops
 ops = comfy.ops.disable_weight_init
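
Editor's note on the registry added above: backends register a callable under a name, first registration wins, and lookups can either raise or fall back to a default. The following is a minimal usage sketch, not part of this commit; the import path comfy.ldm.modules.attention and the (q, k, v, heads, mask=None, **kwargs) call signature are assumptions.

# Minimal usage sketch for the registry above (assumed import path and
# assumed attention-function signature; neither is stated in this diff).
from comfy.ldm.modules.attention import (
    register_attention_function,
    get_attention_function,
)

def my_custom_attention(q, k, v, heads, mask=None, **kwargs):
    # Illustrative backend: delegate to whatever "pytorch" resolves to.
    pytorch_attn = get_attention_function("pytorch")
    return pytorch_attn(q, k, v, heads, mask=mask, **kwargs)

# Registration is first-write-wins; re-using a name only logs a warning.
register_attention_function("my_custom", my_custom_attention)

# Lookup raises KeyError for unknown names unless a default is supplied.
attn = get_attention_function("my_custom")
fallback = get_attention_function("not_registered", default=None)
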
@@ -137,68 +172,76 @@ def has_transformer_options_passed(frame):
 def wrap_attn(func):
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
-        if LOG_ATTN_CALLS:
-            continue_to_add = True
-            to_add = 1000
-            logged_stack = []
-            logged_stack_to_index = -1
-            frame = inspect.currentframe()
-            try:
-                # skip wrapper, start at actual wrapped function
-                frame = frame.f_back
-                while frame and continue_to_add and to_add > 0:
-                    code = frame.f_code
-                    filename = code.co_filename
-                    function = code.co_name
-                    lineno = frame.f_lineno
-                    if function == "_calc_cond_batch_outer":
-                        break
-                    if 'venv' in filename:
-                        frame = frame.f_back
-                        continue
-                    elif 'ComfyUI' not in filename:
-                        frame = frame.f_back
-                        continue
-                    elif 'execution.py' in filename:
-                        frame = frame.f_back
-                        continue
-                    elif 'patcher_extension.py' in filename:
-                        frame = frame.f_back
-                        continue
-                    to_add -= 1
-                    cls_name = get_class_from_frame(frame)
-                    log_string = f"{filename}:{lineno}"
-                    if cls_name:
-                        log_string += f":{cls_name}.{function}"
-                    else:
-                        log_string += f":{function}"
-                    if has_transformer_options_passed(frame):
-                        log_string += ":✅"
-                        if logged_stack_to_index == -1:
-                            logged_stack_to_index = len(logged_stack)
-                    else:
-                        log_string += ":❌"
-                    logged_stack.append(log_string)
-                    # move up the stack
-                    frame = frame.f_back
-                LOG_CONTENTS["|".join(logged_stack)] = (logged_stack_to_index, logged_stack)
-            finally:
-                # Important: break ref cycles so tensors aren't pinned
-                del frame
-        transformer_options = kwargs.get("transformer_options", None)
-        if transformer_options is not None:
-            if "optimized_attention_override" in transformer_options:
-                return transformer_options["optimized_attention_override"](func, transformer_options, *args, **kwargs)
-        return func(*args, **kwargs)
+        remove_attn_wrapper_key = False
+        try:
+            if LOG_ATTN_CALLS:
+                continue_to_add = True
+                to_add = 1000
+                logged_stack = []
+                logged_stack_to_index = -1
+                frame = inspect.currentframe()
+                try:
+                    # skip wrapper, start at actual wrapped function
+                    frame = frame.f_back
+                    while frame and continue_to_add and to_add > 0:
+                        code = frame.f_code
+                        filename = code.co_filename
+                        function = code.co_name
+                        lineno = frame.f_lineno
+                        if function == "_calc_cond_batch_outer":
+                            break
+                        if 'venv' in filename:
+                            frame = frame.f_back
+                            continue
+                        elif 'ComfyUI' not in filename:
+                            frame = frame.f_back
+                            continue
+                        elif 'execution.py' in filename:
+                            frame = frame.f_back
+                            continue
+                        elif 'patcher_extension.py' in filename:
+                            frame = frame.f_back
+                            continue
+                        to_add -= 1
+                        cls_name = get_class_from_frame(frame)
+                        log_string = f"{filename}:{lineno}"
+                        if cls_name:
+                            log_string += f":{cls_name}.{function}"
+                        else:
+                            log_string += f":{function}"
+                        if has_transformer_options_passed(frame):
+                            log_string += ":✅"
+                            if logged_stack_to_index == -1:
+                                logged_stack_to_index = len(logged_stack)
+                        else:
+                            log_string += ":❌"
+                        logged_stack.append(log_string)
+                        # move up the stack
+                        frame = frame.f_back
+                    LOG_CONTENTS["|".join(logged_stack)] = (logged_stack_to_index, logged_stack)
+                finally:
+                    # Important: break ref cycles so tensors aren't pinned
+                    del frame
+            if "_inside_attn_wrapper" not in kwargs:
+                transformer_options = kwargs.get("transformer_options", None)
+                remove_attn_wrapper_key = True
+                kwargs["_inside_attn_wrapper"] = True
+                if transformer_options is not None:
+                    if "optimized_attention_override" in transformer_options:
+                        return transformer_options["optimized_attention_override"](func, *args, **kwargs)
+            return func(*args, **kwargs)
+        finally:
+            if remove_attn_wrapper_key:
+                del kwargs["_inside_attn_wrapper"]
     return wrapper
 @wrap_attn
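
Editor's note: the `_inside_attn_wrapper` sentinel is what keeps wrap_attn from recursing into itself when an optimized_attention_override delegates to another wrapped attention function. Below is a stripped-down model of that guard with simplified names and without the stack-logging; it only assumes the behaviour visible in the diff above.

# Stripped-down model of the recursion guard (illustrative only).
import functools

def wrap_attn(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        remove_key = False
        try:
            if "_inside_attn_wrapper" not in kwargs:
                transformer_options = kwargs.get("transformer_options", None)
                remove_key = True
                kwargs["_inside_attn_wrapper"] = True
                if transformer_options and "optimized_attention_override" in transformer_options:
                    # The override may call another wrapped attention function.
                    # Because the sentinel travels along in kwargs, the inner
                    # wrapper skips this branch instead of invoking the
                    # override again, so the call chain terminates.
                    return transformer_options["optimized_attention_override"](func, *args, **kwargs)
            return func(*args, **kwargs)
        finally:
            if remove_key:
                del kwargs["_inside_attn_wrapper"]
    return wrapper

@wrap_attn
def fallback_attention(*args, **kwargs):
    return "fallback"

@wrap_attn
def fancy_attention(*args, **kwargs):
    return "fancy"

def override(orig_func, *args, **kwargs):
    # Without the sentinel, this call would re-enter the override and
    # recurse forever; with it, fancy_attention runs normally.
    return fancy_attention(*args, **kwargs)

print(fallback_attention(transformer_options={"optimized_attention_override": override}))
# prints: fancy
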
@@ -707,6 +750,7 @@ else:
     else:
         logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
         optimized_attention = attention_sub_quad
+_register_core_attention_functions()
 optimized_attention_masked = optimized_attention
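
Editor's note: since _register_core_attention_functions() runs at import time, right after optimized_attention is chosen, the registry reflects which optional backends were importable (sage, flash) or enabled (xformers) on the running install. A hypothetical way to inspect it, again assuming the comfy.ldm.modules.attention module path:

# Hypothetical inspection snippet; importing the module triggers
# _register_core_attention_functions(), so the registry already holds
# "pytorch", "sub_quad" and "split", plus "sage"/"flash" when those
# packages imported successfully and "xformers" when it is enabled.
from comfy.ldm.modules import attention

print(sorted(attention.REGISTERED_ATTENTION_FUNCTIONS.keys()))
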