Add optimized to get_attention_function

Author: Jedrzej Kosinski
Date: 2025-08-29 21:48:36 -07:00
parent d553073a1e
commit cb959f9669


@@ -51,7 +51,9 @@ def register_attention_function(name: str, func: Callable):
logging.warning(f"Attention function {name} already registered, skipping registration.") logging.warning(f"Attention function {name} already registered, skipping registration.")
def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]: def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]:
if name not in REGISTERED_ATTENTION_FUNCTIONS: if name == "optimized":
return optimized_attention
elif name not in REGISTERED_ATTENTION_FUNCTIONS:
if default is ...: if default is ...:
raise KeyError(f"Attention function {name} not found.") raise KeyError(f"Attention function {name} not found.")
else: else:
@@ -62,7 +64,7 @@ def _register_core_attention_functions():
""" """
Register attention functions exposed by core ComfyUI. Register attention functions exposed by core ComfyUI.
""" """
# NOTE: attention_basic is purposely not registered, as it is not used in code # NOTE: attention_basic is purposely not registered, as it should not be used
if SAGE_ATTENTION_IS_AVAILABLE: if SAGE_ATTENTION_IS_AVAILABLE:
register_attention_function("sage", attention_sage) register_attention_function("sage", attention_sage)
if FLASH_ATTENTION_IS_AVAILABLE: if FLASH_ATTENTION_IS_AVAILABLE:
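
For context, a minimal runnable sketch of how the lookup behaves after this change. The empty registry, the stub optimized_attention, and the final registry-return line are placeholders assumed from the surrounding code, not part of this diff:

from typing import Any, Callable, Union

# Stand-ins for the real module state; the actual registry is populated
# by _register_core_attention_functions() in ComfyUI.
REGISTERED_ATTENTION_FUNCTIONS: dict[str, Callable] = {}

def optimized_attention(*args, **kwargs):
    """Stub for ComfyUI's optimized_attention kernel."""
    raise NotImplementedError

def get_attention_function(name: str, default: Any = ...) -> Union[Callable, None]:
    # New fast path: "optimized" always resolves to optimized_attention,
    # bypassing the registry entirely.
    if name == "optimized":
        return optimized_attention
    elif name not in REGISTERED_ATTENTION_FUNCTIONS:
        if default is ...:
            raise KeyError(f"Attention function {name} not found.")
        else:
            return default
    # Assumed fallthrough: return the registered function on a hit.
    return REGISTERED_ATTENTION_FUNCTIONS[name]

assert get_attention_function("optimized") is optimized_attention
assert get_attention_function("missing", default=None) is None

With this change, callers that pass "optimized" always receive the current best kernel even though it is never placed in the registry, while unknown names keep the existing KeyError/default semantics.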