Add some command line arguments to store text encoder weights in fp8.

Pytorch supports two variants of fp8:
--fp8_e4m3fn-text-enc (the one that seems to give better results)
--fp8_e5m2-text-enc
This commit is contained in:
comfyanonymous
2023-11-17 02:56:59 -05:00
parent 107e78b1cb
commit 0cf4e86939
3 changed files with 23 additions and 4 deletions

View File

@@ -95,10 +95,7 @@ class CLIP:
load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
params['device'] = offload_device
if model_management.should_use_fp16(load_device, prioritize_performance=False):
params['dtype'] = torch.float16
else:
params['dtype'] = torch.float32
params['dtype'] = model_management.text_encoder_dtype(load_device)
self.cond_stage_model = clip(**(params))