Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-09-12 04:27:21 +00:00
Make sure wrap_attn doesn't make itself recurse infinitely, attempt to load SageAttention and FlashAttention if not enabled so that they can be marked as available or not, create registry for available attention
@@ -8,7 +8,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn, einsum
 from einops import rearrange, repeat
-from typing import Optional
+from typing import Optional, Any, Callable, Union
 import logging
 import functools
@@ -21,23 +21,58 @@ if model_management.xformers_enabled():
     import xformers
     import xformers.ops

-if model_management.sage_attention_enabled():
-    try:
-        from sageattention import sageattn
-    except ModuleNotFoundError as e:
+SAGE_ATTENTION_IS_AVAILABLE = False
+try:
+    from sageattention import sageattn
+    SAGE_ATTENTION_IS_AVAILABLE = True
+except ModuleNotFoundError as e:
+    if model_management.sage_attention_enabled():
         if e.name == "sageattention":
             logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
         else:
             raise e
         exit(-1)

-if model_management.flash_attention_enabled():
-    try:
-        from flash_attn import flash_attn_func
-    except ModuleNotFoundError:
+FLASH_ATTENTION_IS_AVAILABLE = False
+try:
+    from flash_attn import flash_attn_func
+    FLASH_ATTENTION_IS_AVAILABLE = True
+except ModuleNotFoundError:
+    if model_management.flash_attention_enabled():
         logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
         exit(-1)

+REGISTERED_ATTENTION_FUNCTIONS = {}
+def register_attention_function(name: str, func: Callable):
+    # avoid replacing existing functions
+    if name not in REGISTERED_ATTENTION_FUNCTIONS:
+        REGISTERED_ATTENTION_FUNCTIONS[name] = func
+    else:
+        logging.warning(f"Attention function {name} already registered, skipping registration.")
+
+def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]:
+    if name not in REGISTERED_ATTENTION_FUNCTIONS:
+        if default is ...:
+            raise KeyError(f"Attention function {name} not found.")
+        else:
+            return default
+    return REGISTERED_ATTENTION_FUNCTIONS[name]
+
+def _register_core_attention_functions():
+    """
+    Register attention functions exposed by core ComfyUI.
+    """
+    # NOTE: attention_basic is purposely not registered, as it is not used in code
+    if SAGE_ATTENTION_IS_AVAILABLE:
+        register_attention_function("sage", attention_sage)
+    if FLASH_ATTENTION_IS_AVAILABLE:
+        register_attention_function("flash", attention_flash)
+    if model_management.xformers_enabled():
+        register_attention_function("xformers", attention_xformers)
+    register_attention_function("pytorch", attention_pytorch)
+    register_attention_function("sub_quad", attention_sub_quad)
+    register_attention_function("split", attention_split)
+
 from comfy.cli_args import args
 import comfy.ops
 ops = comfy.ops.disable_weight_init
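Not part of the diff: a minimal usage sketch of the registry added above, assuming the module is comfy.ldm.modules.attention and that custom backends follow the same (q, k, v, heads, mask=None, ...) calling convention as the built-in attention functions (both assumptions, not confirmed by this commit).

# Hypothetical sketch, not part of the commit: register a custom backend and
# look backends up by name, falling back when a name is missing.
from comfy.ldm.modules import attention  # assumed module path

def my_custom_attention(q, k, v, heads, mask=None, **kwargs):
    # stand-in backend that simply delegates to the stock PyTorch implementation
    return attention.attention_pytorch(q, k, v, heads, mask=mask, **kwargs)

attention.register_attention_function("my_custom", my_custom_attention)

# get_attention_function raises KeyError for unknown names unless a default is passed
backend = attention.get_attention_function("my_custom")
fallback = attention.get_attention_function("not_registered", default=attention.attention_pytorch)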
@@ -137,68 +172,76 @@ def has_transformer_options_passed(frame):
 def wrap_attn(func):
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
-        if LOG_ATTN_CALLS:
-            continue_to_add = True
-            to_add = 1000
-            logged_stack = []
-            logged_stack_to_index = -1
+        remove_attn_wrapper_key = False
+        try:
+            if LOG_ATTN_CALLS:
+                continue_to_add = True
+                to_add = 1000
+                logged_stack = []
+                logged_stack_to_index = -1

-            frame = inspect.currentframe()
-            try:
-                # skip wrapper, start at actual wrapped function
-                frame = frame.f_back
+                frame = inspect.currentframe()
+                try:
+                    # skip wrapper, start at actual wrapped function
+                    frame = frame.f_back

-                while frame and continue_to_add and to_add > 0:
-                    code = frame.f_code
-                    filename = code.co_filename
-                    function = code.co_name
-                    lineno = frame.f_lineno
+                    while frame and continue_to_add and to_add > 0:
+                        code = frame.f_code
+                        filename = code.co_filename
+                        function = code.co_name
+                        lineno = frame.f_lineno

-                    if function == "_calc_cond_batch_outer":
-                        break
-                    if 'venv' in filename:
-                        frame = frame.f_back
-                        continue
-                    elif 'ComfyUI' not in filename:
-                        frame = frame.f_back
-                        continue
-                    elif 'execution.py' in filename:
-                        frame = frame.f_back
-                        continue
-                    elif 'patcher_extension.py' in filename:
-                        frame = frame.f_back
-                        continue
+                        if function == "_calc_cond_batch_outer":
+                            break
+                        if 'venv' in filename:
+                            frame = frame.f_back
+                            continue
+                        elif 'ComfyUI' not in filename:
+                            frame = frame.f_back
+                            continue
+                        elif 'execution.py' in filename:
+                            frame = frame.f_back
+                            continue
+                        elif 'patcher_extension.py' in filename:
+                            frame = frame.f_back
+                            continue

-                    to_add -= 1
-                    cls_name = get_class_from_frame(frame)
-                    log_string = f"{filename}:{lineno}"
-                    if cls_name:
-                        log_string += f":{cls_name}.{function}"
-                    else:
-                        log_string += f":{function}"
+                        to_add -= 1
+                        cls_name = get_class_from_frame(frame)
+                        log_string = f"{filename}:{lineno}"
+                        if cls_name:
+                            log_string += f":{cls_name}.{function}"
+                        else:
+                            log_string += f":{function}"

-                    if has_transformer_options_passed(frame):
-                        log_string += ":✅"
-                        if logged_stack_to_index == -1:
-                            logged_stack_to_index = len(logged_stack)
-                    else:
-                        log_string += ":❌"
+                        if has_transformer_options_passed(frame):
+                            log_string += ":✅"
+                            if logged_stack_to_index == -1:
+                                logged_stack_to_index = len(logged_stack)
+                        else:
+                            log_string += ":❌"

-                    logged_stack.append(log_string)
+                        logged_stack.append(log_string)

-                    # move up the stack
-                    frame = frame.f_back
+                        # move up the stack
+                        frame = frame.f_back

-                LOG_CONTENTS["|".join(logged_stack)] = (logged_stack_to_index, logged_stack)
+                    LOG_CONTENTS["|".join(logged_stack)] = (logged_stack_to_index, logged_stack)

-            finally:
-                # Important: break ref cycles so tensors aren't pinned
-                del frame
-        transformer_options = kwargs.get("transformer_options", None)
-        if transformer_options is not None:
-            if "optimized_attention_override" in transformer_options:
-                return transformer_options["optimized_attention_override"](func, transformer_options, *args, **kwargs)
-        return func(*args, **kwargs)
+                finally:
+                    # Important: break ref cycles so tensors aren't pinned
+                    del frame
+            if "_inside_attn_wrapper" not in kwargs:
+                transformer_options = kwargs.get("transformer_options", None)
+                remove_attn_wrapper_key = True
+                kwargs["_inside_attn_wrapper"] = True
+                if transformer_options is not None:
+                    if "optimized_attention_override" in transformer_options:
+                        return transformer_options["optimized_attention_override"](func, *args, **kwargs)
+            return func(*args, **kwargs)
+        finally:
+            if remove_attn_wrapper_key:
+                del kwargs["_inside_attn_wrapper"]
     return wrapper

 @wrap_attn
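Not part of the diff: a sketch of how the new _inside_attn_wrapper guard interacts with an optimized_attention_override. The wrapper sets the guard key and hands the override the undecorated implementation as func, so an override that re-enters another @wrap_attn-decorated function is not intercepted a second time; the override below and its registry lookup are illustrative assumptions.

# Hypothetical override, not part of the commit. wrap_attn calls it as
# override(func, *args, **kwargs) with "_inside_attn_wrapper" already set in
# kwargs, so a nested call into a wrapped backend skips the override instead
# of recursing.
from comfy.ldm.modules.attention import get_attention_function  # assumed module path

def my_attention_override(func, *args, **kwargs):
    backend = get_attention_function("sage", default=None)  # prefer sage if it was registered
    if backend is not None:
        return backend(*args, **kwargs)
    return func(*args, **kwargs)  # otherwise run the implementation that was about to execute

# enabled per call by placing it in transformer_options, e.g.:
# transformer_options["optimized_attention_override"] = my_attention_override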
@@ -707,6 +750,7 @@ else:
     else:
         logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
         optimized_attention = attention_sub_quad
+_register_core_attention_functions()

 optimized_attention_masked = optimized_attention
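Not part of the diff: because _register_core_attention_functions() now runs when the module is imported, the set of backends that actually loaded on a given install can be inspected directly (module path assumed).

# Hypothetical check, not part of the commit: list which attention backends
# ended up registered after import on this machine.
from comfy.ldm.modules import attention  # assumed module path

print(sorted(attention.REGISTERED_ATTENTION_FUNCTIONS))
# e.g. ['pytorch', 'split', 'sub_quad'] when xformers/sage/flash are unavailable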