diff --git a/comfy/model_management.py b/comfy/model_management.py index a107f0d49..187402748 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -295,6 +295,7 @@ except: pass +SUPPORT_FP8_OPS = args.supports_fp8_compute try: if is_amd(): try: @@ -308,6 +309,10 @@ try: if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much if any((a in arch) for a in ["gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches ENABLE_PYTORCH_ATTENTION = True + if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4): + if any((a in arch) for a in ["gfx1201"]): # TODO: more arches + SUPPORT_FP8_OPS = True + except: pass @@ -1262,7 +1267,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma return False def supports_fp8_compute(device=None): - if args.supports_fp8_compute: + if SUPPORT_FP8_OPS: return True if not is_nvidia():