diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 804fd0df9..e26f66bb3 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -60,21 +60,6 @@ def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]
         return default
     return REGISTERED_ATTENTION_FUNCTIONS[name]
 
-def _register_core_attention_functions():
-    """
-    Register attention functions exposed by core ComfyUI.
-    """
-    # NOTE: attention_basic is purposely not registered, as it should not be used
-    if SAGE_ATTENTION_IS_AVAILABLE:
-        register_attention_function("sage", attention_sage)
-    if FLASH_ATTENTION_IS_AVAILABLE:
-        register_attention_function("flash", attention_flash)
-    if model_management.xformers_enabled():
-        register_attention_function("xformers", attention_xformers)
-    register_attention_function("pytorch", attention_pytorch)
-    register_attention_function("sub_quad", attention_sub_quad)
-    register_attention_function("split", attention_split)
-
 from comfy.cli_args import args
 import comfy.ops
 ops = comfy.ops.disable_weight_init
@@ -657,10 +642,22 @@ else:
     else:
         logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
         optimized_attention = attention_sub_quad
-_register_core_attention_functions()
 
 optimized_attention_masked = optimized_attention
 
+
+# register core-supported attention functions
+if SAGE_ATTENTION_IS_AVAILABLE:
+    register_attention_function("sage", attention_sage)
+if FLASH_ATTENTION_IS_AVAILABLE:
+    register_attention_function("flash", attention_flash)
+if model_management.xformers_enabled():
+    register_attention_function("xformers", attention_xformers)
+register_attention_function("pytorch", attention_pytorch)
+register_attention_function("sub_quad", attention_sub_quad)
+register_attention_function("split", attention_split)
+
+
 def optimized_attention_for_device(device, mask=False, small_input=False):
     if small_input:
         if model_management.pytorch_attention_enabled():
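
For context, a minimal sketch of how the registry touched by this diff could be used from outside the module. It only relies on the `register_attention_function` / `get_attention_function` helpers and the `attention_pytorch` implementation visible in the patch; the alias name and the import path are illustrative assumptions based on the file location in the diff header.

```python
# Sketch only: exercises the registry shown in the diff above.
# Assumes the module path from the diff header (comfy/ldm/modules/attention.py).
from comfy.ldm.modules.attention import (
    register_attention_function,
    get_attention_function,
    attention_pytorch,
)

# Register an existing core implementation under a custom key (hypothetical name).
register_attention_function("my_pytorch_alias", attention_pytorch)

# Look it up later; passing default=None avoids the KeyError raised for unknown names.
attn = get_attention_function("my_pytorch_alias", default=None)
if attn is not None:
    print("resolved attention function:", attn.__name__)
```

The patch itself only moves the core registrations from the private `_register_core_attention_functions()` helper to module level, after `optimized_attention` has been selected, so the registered set is identical.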