Initialize text encoder to target dtype.

This commit is contained in:
comfyanonymous
2023-08-23 21:01:15 -04:00
parent f081017c1a
commit 00c0b2c507
5 changed files with 29 additions and 15 deletions

View File

@@ -43,7 +43,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"hidden"
]
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None): # clip-vit-base-patch32
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
self.num_layers = 12
@@ -54,10 +54,12 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
config = CLIPTextConfig.from_json_file(textmodel_json_config)
self.num_layers = config.num_hidden_layers
with comfy.ops.use_comfy_ops():
with comfy.ops.use_comfy_ops(device, dtype):
with modeling_utils.no_init_weights():
self.transformer = CLIPTextModel(config)
if dtype is not None:
self.transformer.to(dtype)
self.max_length = max_length
if freeze:
self.freeze()