Use fp16 if checkpoint weights are fp16 and the model supports it.

This commit is contained in:
comfyanonymous
2025-02-27 16:39:57 -05:00
parent f4dac8ab6f
commit 1804397952
3 changed files with 12 additions and 20 deletions

View File

@@ -418,10 +418,7 @@ def controlnet_config(sd, model_options={}):
weight_dtype = comfy.utils.weight_dtype(sd)
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype)
load_device = comfy.model_management.get_torch_device()
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
@@ -689,10 +686,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
if supported_inference_dtypes is None:
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype)
load_device = comfy.model_management.get_torch_device()