Add some command line arguments to store text encoder weights in fp8.

PyTorch supports two variants of fp8:
--fp8_e4m3fn-text-enc (the one that seems to give better results)
--fp8_e5m2-text-enc
This commit is contained in:
comfyanonymous
2023-11-17 02:56:59 -05:00
parent 107e78b1cb
commit 0cf4e86939
3 changed files with 23 additions and 4 deletions

View File

@@ -482,6 +482,21 @@ def text_encoder_device():
else:
return torch.device("cpu")
def text_encoder_dtype(device=None):
    """Choose the torch dtype used for text encoder weights.

    Explicit command-line overrides (fp8 e4m3fn / fp8 e5m2 / fp16 / fp32)
    take priority, checked in that fixed order. With no override, defer to
    should_use_fp16() — prioritizing accuracy over performance — to pick
    between float16 and float32.
    """
    # CLI overrides: first matching flag wins.
    if args.fp8_e4m3fn_text_enc:
        return torch.float8_e4m3fn
    if args.fp8_e5m2_text_enc:
        return torch.float8_e5m2
    if args.fp16_text_enc:
        return torch.float16
    if args.fp32_text_enc:
        return torch.float32
    # No override: heuristic decision based on the target device.
    if should_use_fp16(device, prioritize_performance=False):
        return torch.float16
    return torch.float32
def vae_device():
    """Return the device the VAE runs on (same as the main torch device)."""
    return get_torch_device()