UNET weights can now be stored in fp8.

--fp8_e4m3fn-unet and --fp8_e5m2-unet select between the two fp8
formats supported by pytorch.
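
For context, a minimal sketch (not part of this commit) of the two fp8 dtypes in pytorch; requires torch >= 2.1:

import torch

# e4m3fn: 4 exponent bits, 3 mantissa bits -> more precision, smaller range
# e5m2:   5 exponent bits, 2 mantissa bits -> wider range, less precision
w = torch.randn(4, 4)
w_e4m3 = w.to(torch.float8_e4m3fn)  # fp8 used purely as a storage dtype
w_e5m2 = w.to(torch.float8_e5m2)

# weights are upcast to a compute dtype before any math is done on them
y = w_e4m3.to(torch.float32) @ torch.ones(4, 4)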
comfyanonymous
2023-12-04 11:10:00 -05:00
parent af365e4dd1
commit 31b0f6f3d8
6 changed files with 47 additions and 10 deletions

comfy/model_management.py

@@ -459,6 +459,10 @@ def unet_inital_load_device(parameters, dtype):
 def unet_dtype(device=None, model_params=0):
     if args.bf16_unet:
         return torch.bfloat16
+    if args.fp8_e4m3fn_unet:
+        return torch.float8_e4m3fn
+    if args.fp8_e5m2_unet:
+        return torch.float8_e5m2
     if should_use_fp16(device=device, model_params=model_params):
         return torch.float16
     return torch.float32
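
The args flags checked above come from comfy's CLI parser; a hypothetical minimal reproduction of the wiring (the real definitions live elsewhere in this commit, presumably comfy/cli_args.py):

import argparse

parser = argparse.ArgumentParser()
# only one unet dtype override makes sense at a time, so group them
fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--bf16-unet", action="store_true")
fp_group.add_argument("--fp8_e4m3fn-unet", action="store_true")
fp_group.add_argument("--fp8_e5m2-unet", action="store_true")

args = parser.parse_args(["--fp8_e4m3fn-unet"])
assert args.fp8_e4m3fn_unet  # argparse turns "-" into "_" for attribute names
# with this flag set, unet_dtype() above returns torch.float8_e4m3fn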
@@ -515,6 +519,17 @@ def get_autocast_device(dev):
return dev.type
return "cuda"
def supports_dtype(device, dtype): #TODO
if dtype == torch.float32:
return True
if torch.device("cpu") == device:
return False
if dtype == torch.float16:
return True
if dtype == torch.bfloat16:
return True
return False
def cast_to_device(tensor, device, dtype, copy=False):
device_supports_cast = False
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
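
A hypothetical usage sketch for the new supports_dtype() helper (assumes the function above is in scope; pick_compute_dtype is illustrative, not part of the commit):

import torch

def pick_compute_dtype(device):
    # prefer bf16, then fp16, falling back to fp32 when unsupported
    for dtype in (torch.bfloat16, torch.float16):
        if supports_dtype(device, dtype):
            return dtype
    return torch.float32

dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(pick_compute_dtype(dev))  # fp32 on cpu, given the TODO-simplified checks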