diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index b3bb71734..f6013672b 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -8,7 +8,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn, einsum
 from einops import rearrange, repeat
-from typing import Optional
+from typing import Optional, Any, Callable, Union
 import logging
 import functools

@@ -21,23 +21,58 @@ if model_management.xformers_enabled():
     import xformers
     import xformers.ops

-if model_management.sage_attention_enabled():
-    try:
-        from sageattention import sageattn
-    except ModuleNotFoundError as e:
+SAGE_ATTENTION_IS_AVAILABLE = False
+try:
+    from sageattention import sageattn
+    SAGE_ATTENTION_IS_AVAILABLE = True
+except ModuleNotFoundError as e:
+    if model_management.sage_attention_enabled():
         if e.name == "sageattention":
             logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
         else:
             raise e
         exit(-1)

-if model_management.flash_attention_enabled():
-    try:
-        from flash_attn import flash_attn_func
-    except ModuleNotFoundError:
+FLASH_ATTENTION_IS_AVAILABLE = False
+try:
+    from flash_attn import flash_attn_func
+    FLASH_ATTENTION_IS_AVAILABLE = True
+except ModuleNotFoundError:
+    if model_management.flash_attention_enabled():
         logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
         exit(-1)

+REGISTERED_ATTENTION_FUNCTIONS = {}
+def register_attention_function(name: str, func: Callable):
+    # avoid replacing existing functions
+    if name not in REGISTERED_ATTENTION_FUNCTIONS:
+        REGISTERED_ATTENTION_FUNCTIONS[name] = func
+    else:
+        logging.warning(f"Attention function {name} already registered, skipping registration.")
+
+def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]:
+    if name not in REGISTERED_ATTENTION_FUNCTIONS:
+        if default is ...:
+            raise KeyError(f"Attention function {name} not found.")
+        else:
+            return default
+    return REGISTERED_ATTENTION_FUNCTIONS[name]
+
+def _register_core_attention_functions():
+    """
+    Register attention functions exposed by core ComfyUI.
+    """
+    # NOTE: attention_basic is purposely not registered, as it is not used in code
+    if SAGE_ATTENTION_IS_AVAILABLE:
+        register_attention_function("sage", attention_sage)
+    if FLASH_ATTENTION_IS_AVAILABLE:
+        register_attention_function("flash", attention_flash)
+    if model_management.xformers_enabled():
+        register_attention_function("xformers", attention_xformers)
+    register_attention_function("pytorch", attention_pytorch)
+    register_attention_function("sub_quad", attention_sub_quad)
+    register_attention_function("split", attention_split)
+
 from comfy.cli_args import args
 import comfy.ops
 ops = comfy.ops.disable_weight_init
@@ -137,68 +172,76 @@ def has_transformer_options_passed(frame):
 def wrap_attn(func):
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
-        if LOG_ATTN_CALLS:
-            continue_to_add = True
-            to_add = 1000
-            logged_stack = []
-            logged_stack_to_index = -1
+        remove_attn_wrapper_key = False
+        try:
+            if LOG_ATTN_CALLS:
+                continue_to_add = True
+                to_add = 1000
+                logged_stack = []
+                logged_stack_to_index = -1

-            frame = inspect.currentframe()
-            try:
-                # skip wrapper, start at actual wrapped function
-                frame = frame.f_back
-
-                while frame and continue_to_add and to_add > 0:
-                    code = frame.f_code
-                    filename = code.co_filename
-                    function = code.co_name
-                    lineno = frame.f_lineno
-
-                    if function == "_calc_cond_batch_outer":
-                        break
-                    if 'venv' in filename:
-                        frame = frame.f_back
-                        continue
-                    elif 'ComfyUI' not in filename:
-                        frame = frame.f_back
-                        continue
-                    elif 'execution.py' in filename:
-                        frame = frame.f_back
-                        continue
-                    elif 'patcher_extension.py' in filename:
-                        frame = frame.f_back
-                        continue
-
-                    to_add -= 1
-                    cls_name = get_class_from_frame(frame)
-                    log_string = f"{filename}:{lineno}"
-                    if cls_name:
-                        log_string += f":{cls_name}.{function}"
-                    else:
-                        log_string += f":{function}"
-
-                    if has_transformer_options_passed(frame):
-                        log_string += ":✅"
-                        if logged_stack_to_index == -1:
-                            logged_stack_to_index = len(logged_stack)
-                    else:
-                        log_string += ":❌"
-
-                    logged_stack.append(log_string)
-
-                    # move up the stack
+                frame = inspect.currentframe()
+                try:
+                    # skip wrapper, start at actual wrapped function
                     frame = frame.f_back

-                LOG_CONTENTS["|".join(logged_stack)] = (logged_stack_to_index, logged_stack)
+                    while frame and continue_to_add and to_add > 0:
+                        code = frame.f_code
+                        filename = code.co_filename
+                        function = code.co_name
+                        lineno = frame.f_lineno

-            finally:
-                # Important: break ref cycles so tensors aren't pinned
-                del frame
-        transformer_options = kwargs.get("transformer_options", None)
-        if transformer_options is not None:
-            if "optimized_attention_override" in transformer_options:
-                return transformer_options["optimized_attention_override"](func, transformer_options, *args, **kwargs)
-        return func(*args, **kwargs)
+                        if function == "_calc_cond_batch_outer":
+                            break
+                        if 'venv' in filename:
+                            frame = frame.f_back
+                            continue
+                        elif 'ComfyUI' not in filename:
+                            frame = frame.f_back
+                            continue
+                        elif 'execution.py' in filename:
+                            frame = frame.f_back
+                            continue
+                        elif 'patcher_extension.py' in filename:
+                            frame = frame.f_back
+                            continue
+
+                        to_add -= 1
+                        cls_name = get_class_from_frame(frame)
+                        log_string = f"{filename}:{lineno}"
+                        if cls_name:
+                            log_string += f":{cls_name}.{function}"
+                        else:
+                            log_string += f":{function}"
+
+                        if has_transformer_options_passed(frame):
+                            log_string += ":✅"
+                            if logged_stack_to_index == -1:
+                                logged_stack_to_index = len(logged_stack)
+                        else:
+                            log_string += ":❌"
+
+                        logged_stack.append(log_string)
+
+                        # move up the stack
+                        frame = frame.f_back
+
+                    LOG_CONTENTS["|".join(logged_stack)] = (logged_stack_to_index, logged_stack)
+
+                finally:
+                    # Important: break ref cycles so tensors aren't pinned
+                    del frame
+            if "_inside_attn_wrapper" not in kwargs:
+                transformer_options = kwargs.get("transformer_options", None)
+                remove_attn_wrapper_key = True
+                kwargs["_inside_attn_wrapper"] = True
+                if transformer_options is not None:
+                    if "optimized_attention_override" in transformer_options:
+                        return transformer_options["optimized_attention_override"](func, *args, **kwargs)
+            return func(*args, **kwargs)
+        finally:
+            if remove_attn_wrapper_key:
+                del kwargs["_inside_attn_wrapper"]
     return wrapper

 @wrap_attn
@@ -707,6 +750,7 @@ else:
     else:
         logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
         optimized_attention = attention_sub_quad
+_register_core_attention_functions()

 optimized_attention_masked = optimized_attention

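As a usage sketch for the new registry (not part of the patch): a custom backend registers once and can then be fetched by name. The backend name `my_attention` and its body below are hypothetical, and the `(q, k, v, heads, mask=None, ...)` signature is assumed to match the convention of the core attention functions in this file, with `**kwargs` absorbing extra keys such as `transformer_options` and the wrapper's `_inside_attn_wrapper`:

```python
import torch
from comfy.ldm.modules.attention import register_attention_function, get_attention_function

def my_attention(q, k, v, heads, mask=None, **kwargs):
    # hypothetical stand-in backend; assumes q/k/v are (batch, seq, heads * dim_head)
    b, dim_head = q.shape[0], q.shape[-1] // heads
    q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

register_attention_function("my_attention", my_attention)

get_attention_function("my_attention")           # returns my_attention
get_attention_function("missing", default=None)  # returns the fallback instead of raising
get_attention_function("missing")                # raises KeyError
```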
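The `optimized_attention_override` contract also changes shape here: the override is now called as `override(func, *args, **kwargs)`, with `transformer_options` travelling inside `kwargs` rather than as a separate second argument. Because `wrap_attn` injects `_inside_attn_wrapper` into `kwargs` before dispatching, an override that forwards `**kwargs` into another wrapped attention function should not be re-entered. A hypothetical override illustrating both points (`swap_to_split` is an invented name):

```python
from comfy.ldm.modules.attention import get_attention_function

def swap_to_split(func, *args, **kwargs):
    # redirect this call to the registered "split" backend when available
    split = get_attention_function("split", default=None)
    if split is not None:
        # kwargs still carries _inside_attn_wrapper=True, so the wrap_attn
        # wrapper around "split" skips the override instead of recursing
        return split(*args, **kwargs)
    # otherwise run the implementation that would have been used anyway
    return func(*args, **kwargs)

# installed wherever transformer_options is assembled, e.g.:
# transformer_options["optimized_attention_override"] = swap_to_split
```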