Make highvram and normalvram shift the text encoders to VRAM and back.

This is faster for large text encoder models than running them on the CPU.
comfyanonymous
2023-07-01 12:37:23 -04:00
parent fa1959e3ef
commit 97ee230682
3 changed files with 46 additions and 20 deletions
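
To make the pattern concrete before the hunks below: the text encoder is built once, optionally cast to fp16, and then parked on an offload device (normally the CPU) until a prompt actually needs encoding. The following is a minimal sketch of that setup, assuming a plain torch.nn.Module encoder; the helper functions and the OffloadedTextEncoder class are hypothetical stand-ins for model_management.text_encoder_device() and text_encoder_offload_device(), not ComfyUI's real implementation.

import torch

def pick_text_encoder_device():
    # Hypothetical stand-in: run the text encoder on the GPU when one exists.
    return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def pick_offload_device():
    # Hypothetical stand-in: park the text encoder on the CPU between prompts.
    return torch.device("cpu")

class OffloadedTextEncoder:
    def __init__(self, encoder: torch.nn.Module):
        self.device = pick_text_encoder_device()
        if self.device.type == "cuda":
            encoder = encoder.half()  # fp16 is only worthwhile on the GPU in this sketch
        # Keep the weights on the offload device while idle so VRAM stays free.
        self.encoder = encoder.to(pick_offload_device())
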


@@ -526,9 +526,10 @@ class CLIP:
         tokenizer = target.tokenizer
         self.device = model_management.text_encoder_device()
         params["device"] = self.device
         self.cond_stage_model = clip(**(params))
-        self.cond_stage_model = self.cond_stage_model.to(self.device)
         if model_management.should_use_fp16(self.device):
             self.cond_stage_model.half()
+        self.cond_stage_model = self.cond_stage_model.to(model_management.text_encoder_offload_device())
         self.tokenizer = tokenizer(embedding_directory=embedding_directory)
         self.patcher = ModelPatcher(self.cond_stage_model)
@@ -559,11 +560,14 @@ class CLIP:
         if self.layer_idx is not None:
             self.cond_stage_model.clip_layer(self.layer_idx)
         try:
+            self.cond_stage_model.to(self.device)
             self.patch_model()
             cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
             self.unpatch_model()
+            self.cond_stage_model.to(model_management.text_encoder_offload_device())
         except Exception as e:
             self.unpatch_model()
+            self.cond_stage_model.to(model_management.text_encoder_offload_device())
             raise e
         cond_out = cond
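
For reference, the round trip in the hunk above (move to the compute device, encode, then move back even if encoding fails) can be sketched as a standalone function. The names below are hypothetical, not ComfyUI's API, and a try/finally stands in for the diff's pattern of repeating the offload call in both the success and the exception path.

import torch

def encode_with_offload(encoder, tokens, device, offload_device):
    # Shift the weights into VRAM only for the duration of the forward pass.
    encoder.to(device)
    try:
        with torch.no_grad():
            out = encoder(tokens.to(device))
    finally:
        # Always shift back, whether encoding succeeded or raised an error.
        encoder.to(offload_device)
    return out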