Cleaner CLIP text encoder implementation.

Use a simple CLIP model implementation instead of the one from
transformers.

This will allow some interesting things that would be too hackish to implement
using the transformers implementation.
comfyanonymous
2023-12-06 15:55:09 -05:00
parent 2db86b4676
commit fbdb14d4c4
5 changed files with 172 additions and 49 deletions


@@ -3,13 +3,13 @@ import torch
 import os

 class SDXLClipG(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
+    def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
         if layer == "penultimate":
             layer="hidden"
             layer_idx=-2

         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
                          special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)

     def load_sd(self, sd):
@@ -37,7 +37,7 @@ class SDXLTokenizer:
 class SDXLClipModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None):
         super().__init__()
-        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
+        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
         self.clip_g = SDXLClipG(device=device, dtype=dtype)

     def clip_layer(self, layer_idx):
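
For context, a minimal sketch (not part of the diff) of how the changed constructors would be called after this commit; the comfy.sdxl_clip import path and the torch.float16 dtype are assumptions for illustration:

import torch
# Assumed import path for illustration; the diff only shows the module contents.
from comfy import sdxl_clip

# SDXLClipG no longer accepts a textmodel_path argument; the config comes from
# clip_config_bigg.json and weights are loaded afterwards via load_sd(state_dict).
clip_g = sdxl_clip.SDXLClipG(device="cpu", dtype=torch.float16)

# SDXLClipModel wires up both text encoders; clip_l now selects the penultimate
# hidden layer via layer_idx=-2 instead of the absolute index 11.
clip_model = sdxl_clip.SDXLClipModel(device="cpu", dtype=torch.float16)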