diff --git a/comfy/model_management.py b/comfy/model_management.py
index 17516b6ed..cb6580f73 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -22,6 +22,7 @@ from enum import Enum
 from comfy.cli_args import args, PerformanceFeature
 import torch
 import sys
+import importlib
 import platform
 import weakref
 import gc
@@ -336,12 +337,13 @@ try:
         logging.info("AMD arch: {}".format(arch))
         logging.info("ROCm version: {}".format(rocm_version))
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
-            if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
-                if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
-                    ENABLE_PYTORCH_ATTENTION = True
-#            if torch_version_numeric >= (2, 8):
-#                if any((a in arch) for a in ["gfx1201"]):
-#                    ENABLE_PYTORCH_ATTENTION = True
+            if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not.
+                if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
+                    if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
+                        ENABLE_PYTORCH_ATTENTION = True
+#                if torch_version_numeric >= (2, 8):
+#                    if any((a in arch) for a in ["gfx1201"]):
+#                        ENABLE_PYTORCH_ATTENTION = True
         if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
             if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
                 SUPPORT_FP8_OPS = True
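
For context, importlib.util.find_spec(name) asks the import machinery whether name is importable and returns a ModuleSpec (or None) without executing the module, so the check is cheap and side-effect free. Below is a minimal standalone sketch of the same gating pattern, using a hypothetical has_module helper; note the explicit `import importlib.util`, which is the robust form. As the TODO in the diff says, this only detects that triton is installed, not that PyTorch's attention kernels were actually compiled against it.

    import importlib.util

    def has_module(name: str) -> bool:
        # find_spec() returns a ModuleSpec if `name` can be imported,
        # or None if it cannot; the module itself is never executed.
        return importlib.util.find_spec(name) is not None

    # Same gating idea as the diff: only flip the feature flag on
    # when the optional dependency is present.
    ENABLE_PYTORCH_ATTENTION = False
    if has_module("triton"):
        ENABLE_PYTORCH_ATTENTION = True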