Use fp16 if checkpoint weights are fp16 and the model supports it.

This commit is contained in:
comfyanonymous
2025-02-27 16:39:57 -05:00
parent f4dac8ab6f
commit 1804397952
3 changed files with 12 additions and 20 deletions

View File

@@ -674,7 +674,7 @@ def unet_inital_load_device(parameters, dtype):
def maximum_vram_for_weights(device=None):
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32], weight_dtype=None):
if model_params < 0:
model_params = 1000000000000000000000
if args.fp32_unet:
@@ -692,10 +692,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
fp8_dtype = None
try:
for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
if dtype in supported_dtypes:
fp8_dtype = dtype
break
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
fp8_dtype = weight_dtype
except:
pass
@@ -707,7 +705,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
if model_params * 2 > free_model_memory:
return fp8_dtype
if PRIORITIZE_FP16:
if PRIORITIZE_FP16 or weight_dtype == torch.float16:
if torch.float16 in supported_dtypes and should_use_fp16(device=device, model_params=model_params):
return torch.float16