Save memory by storing text encoder weights in fp16 in most situations.

Do inference in fp32 so the output quality stays exactly the same.
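
A minimal standalone sketch of the pattern behind this change (an illustration, not code from the repo; `encode` and `text_encoder_stub` are assumed names, and it assumes a CUDA device plus a PyTorch version that accepts a float32 autocast target, which has varied across releases): the weights stay in fp16 to halve their memory footprint, while an fp32 autocast region upcasts them for the actual computation.

import contextlib
import torch

# Illustrative stand-in for the CLIP text encoder; not code from the repo.
text_encoder_stub = torch.nn.Linear(768, 768)

def encode(module, x):
    if next(module.parameters()).dtype != torch.float32:
        # Weights are stored in fp16: run the math in fp32 via autocast,
        # which upcasts eligible ops to the requested dtype.
        precision_scope = torch.autocast
    else:
        # Weights are already fp32: a no-op context that accepts the same
        # (device, dtype) arguments, mirroring the lambda in the diff below.
        precision_scope = lambda device, dtype: contextlib.nullcontext()
    with precision_scope("cuda", torch.float32):
        return module(x)

if torch.cuda.is_available():
    text_encoder_stub = text_encoder_stub.half().cuda()  # fp16 storage
    x = torch.randn(2, 77, 768, device="cuda", dtype=torch.float16)
    print(encode(text_encoder_stub, x).dtype)  # expected: torch.float32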
comfyanonymous
2023-08-23 01:07:57 -04:00
parent d7b3b0f8c1
commit f081017c1a
4 changed files with 5 additions and 9 deletions

@@ -137,9 +137,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         if backup_embeds.weight.dtype != torch.float32:
             precision_scope = torch.autocast
         else:
-            precision_scope = contextlib.nullcontext
+            precision_scope = lambda a, b: contextlib.nullcontext(a)
 
-        with precision_scope(model_management.get_autocast_device(device)):
+        with precision_scope(model_management.get_autocast_device(device), torch.float32):
             outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
             self.transformer.set_input_embeddings(backup_embeds)
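
Note on the design choice: the lambda in the else branch exists only to give both branches the same (device, dtype) call signature, so the with statement can pass torch.float32 unconditionally. torch.autocast consumes the dtype argument, while the fp32 path wraps contextlib.nullcontext and simply ignores it.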