mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-12 12:37:01 +00:00
Fix CLIPSetLastLayer not reverting when removed.
This commit is contained in:
@@ -46,12 +46,14 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
self.num_layers = 12
|
||||
if textmodel_path is not None:
|
||||
self.transformer = CLIPTextModel.from_pretrained(textmodel_path)
|
||||
else:
|
||||
if textmodel_json_config is None:
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
|
||||
config = CLIPTextConfig.from_json_file(textmodel_json_config)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
with comfy.ops.use_comfy_ops():
|
||||
with modeling_utils.no_init_weights():
|
||||
self.transformer = CLIPTextModel(config)
|
||||
@@ -66,8 +68,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
self.layer_norm_hidden_state = True
|
||||
if layer == "hidden":
|
||||
assert layer_idx is not None
|
||||
assert abs(layer_idx) <= 12
|
||||
assert abs(layer_idx) <= self.num_layers
|
||||
self.clip_layer(layer_idx)
|
||||
self.layer_default = (self.layer, self.layer_idx)
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
@@ -76,12 +79,16 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
param.requires_grad = False
|
||||
|
||||
def clip_layer(self, layer_idx):
|
||||
if abs(layer_idx) >= 12:
|
||||
if abs(layer_idx) >= self.num_layers:
|
||||
self.layer = "last"
|
||||
else:
|
||||
self.layer = "hidden"
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def reset_clip_layer(self):
|
||||
self.layer = self.layer_default[0]
|
||||
self.layer_idx = self.layer_default[1]
|
||||
|
||||
def set_up_textual_embeddings(self, tokens, current_embeds):
|
||||
out_tokens = []
|
||||
next_new_token = token_dict_size = current_embeds.weight.shape[0]
|
||||
|
Reference in New Issue
Block a user