mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 03:58:22 +00:00
UNET weights can now be stored in fp8.
--fp8_e4m3fn-unet and --fp8_e5m2-unet are the two different formats supported by pytorch.
This commit is contained in:
@@ -459,6 +459,10 @@ def unet_inital_load_device(parameters, dtype):
|
||||
def unet_dtype(device=None, model_params=0):
|
||||
if args.bf16_unet:
|
||||
return torch.bfloat16
|
||||
if args.fp8_e4m3fn_unet:
|
||||
return torch.float8_e4m3fn
|
||||
if args.fp8_e5m2_unet:
|
||||
return torch.float8_e5m2
|
||||
if should_use_fp16(device=device, model_params=model_params):
|
||||
return torch.float16
|
||||
return torch.float32
|
||||
@@ -515,6 +519,17 @@ def get_autocast_device(dev):
|
||||
return dev.type
|
||||
return "cuda"
|
||||
|
||||
def supports_dtype(device, dtype): #TODO
|
||||
if dtype == torch.float32:
|
||||
return True
|
||||
if torch.device("cpu") == device:
|
||||
return False
|
||||
if dtype == torch.float16:
|
||||
return True
|
||||
if dtype == torch.bfloat16:
|
||||
return True
|
||||
return False
|
||||
|
||||
def cast_to_device(tensor, device, dtype, copy=False):
|
||||
device_supports_cast = False
|
||||
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
|
||||
|
Reference in New Issue
Block a user