Automatically use fp8 for diffusion model weights if:

Checkpoint contains weights in fp8.

There isn't enough memory to load the diffusion model in GPU vram.
This commit is contained in:
comfyanonymous
2024-08-03 13:45:19 -04:00
parent f123328b82
commit ba9095e5bd
4 changed files with 34 additions and 4 deletions

View File

@@ -527,6 +527,9 @@ def unet_inital_load_device(parameters, dtype):
else:
return cpu_dev
def maximum_vram_for_weights(device=None):
return (get_total_memory(device) * 0.8 - minimum_inference_memory())
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if args.bf16_unet:
return torch.bfloat16
@@ -536,6 +539,21 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
return torch.float8_e4m3fn
if args.fp8_e5m2_unet:
return torch.float8_e5m2
fp8_dtype = None
try:
for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
if dtype in supported_dtypes:
fp8_dtype = dtype
break
except:
pass
if fp8_dtype is not None:
free_model_memory = maximum_vram_for_weights(device)
if model_params * 2 > free_model_memory:
return fp8_dtype
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
if torch.float16 in supported_dtypes:
return torch.float16
@@ -871,7 +889,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
fp16_works = True
if fp16_works or manual_cast:
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
free_model_memory = maximum_vram_for_weights(device)
if (not prioritize_performance) or model_params * 4 > free_model_memory:
return True
@@ -920,7 +938,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
bf16_works = torch.cuda.is_bf16_supported()
if bf16_works or manual_cast:
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
free_model_memory = maximum_vram_for_weights(device)
if (not prioritize_performance) or model_params * 4 > free_model_memory:
return True