From 929e266f3e298478a1433fcff8b0209e52790068 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 17 Feb 2024 08:13:17 -0500
Subject: [PATCH] Manual cast for bf16 on older GPUs.

---
 comfy/model_management.py | 25 +++++++++++++++++++++++--
 1 file changed, 23 insertions(+), 2 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index f0f4ebf58..681208ea0 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -499,7 +499,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
     if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
         if torch.float16 in supported_dtypes:
             return torch.float16
-    if should_use_bf16(device):
+    if should_use_bf16(device, model_params=model_params, manual_cast=True):
         if torch.bfloat16 in supported_dtypes:
             return torch.bfloat16
     return torch.float32
@@ -771,10 +771,24 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
 
     return True
 
-def should_use_bf16(device=None):
+def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
+    if device is not None:
+        if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
+            return False
+
+    if device is not None: #TODO not sure about mps bf16 support
+        if is_device_mps(device):
+            return False
+
     if FORCE_FP32:
         return False
 
+    if directml_enabled:
+        return False
+
+    if cpu_mode() or mps_mode():
+        return False
+
     if is_intel_xpu():
         return True
 
@@ -785,6 +799,13 @@
     if props.major >= 8:
         return True
 
+    bf16_works = torch.cuda.is_bf16_supported()
+
+    if bf16_works or manual_cast:
+        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+        if (not prioritize_performance) or model_params * 4 > free_model_memory:
+            return True
+
     return False
 
 def soft_empty_cache(force=False):
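
Note: the core of the new check compares the model's fp32 footprint (4 bytes per parameter) against roughly 90% of free device memory minus an inference reserve, and prefers bf16 when the full-precision weights would not fit or when performance is not being prioritized. Below is a minimal standalone sketch of that decision only; the function name, the 10 GiB free-memory figure, and the 1 GiB reserve are placeholder values standing in for get_free_memory() and minimum_inference_memory(), not the actual ComfyUI calls.

    # Simplified sketch of the memory-headroom part of the bf16 decision.
    def prefers_bf16(model_params: int, bf16_works: bool, manual_cast: bool,
                     prioritize_performance: bool = True) -> bool:
        free_memory = 10 * 1024**3          # placeholder for get_free_memory()
        inference_reserve = 1 * 1024**3     # placeholder for minimum_inference_memory()

        if bf16_works or manual_cast:
            free_model_memory = free_memory * 0.9 - inference_reserve
            # fp32 weights take 4 bytes per parameter; prefer bf16 if they would
            # not fit, or if performance is not being prioritized.
            if (not prioritize_performance) or model_params * 4 > free_model_memory:
                return True
        return False

    # A 3.5B-parameter model (~14 GB in fp32) exceeds ~8.6 GB of headroom:
    print(prefers_bf16(model_params=3_500_000_000, bf16_works=True, manual_cast=False))  # True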