Load the SD3 T5xxl model in the same dtype stored in the checkpoint.

This commit is contained in:
comfyanonymous
2024-06-11 17:03:26 -04:00
parent 5889b7ca0a
commit 0e49211a11
6 changed files with 49 additions and 6 deletions

View File

@@ -98,13 +98,19 @@ class CLIP:
load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
params['device'] = offload_device
params['dtype'] = model_management.text_encoder_dtype(load_device)
dtype = model_management.text_encoder_dtype(load_device)
params['dtype'] = dtype
self.cond_stage_model = clip(**(params))
for dt in self.cond_stage_model.dtypes:
if not model_management.supports_cast(load_device, dt):
load_device = offload_device
self.tokenizer = tokenizer(embedding_directory=embedding_directory)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
self.layer_idx = None
logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
def clone(self):
n = CLIP(no_init=True)