Add a --bf16-unet to test running the unet in bf16.

This commit is contained in:
comfyanonymous
2023-10-13 14:51:10 -04:00
parent 9a55dadb4c
commit fd4c5f07e7
2 changed files with 4 additions and 0 deletions

View File

@@ -449,6 +449,8 @@ def unet_inital_load_device(parameters, dtype):
return cpu_dev
def unet_dtype(device=None, model_params=0):
if args.bf16_unet:
return torch.bfloat16
if should_use_fp16(device=device, model_params=model_params):
return torch.float16
return torch.float32