Properly set the clip text pooled projection instead of using a hack.

This commit is contained in:
comfyanonymous
2024-08-20 10:00:16 -04:00
parent 538cb068bc
commit 83dbac28eb
4 changed files with 8 additions and 6 deletions

View File

@@ -123,7 +123,6 @@ class CLIPTextModel(torch.nn.Module):
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
embed_dim = config_dict["hidden_size"]
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
self.text_projection.weight.copy_(torch.eye(embed_dim))
self.dtype = dtype
def get_input_embeddings(self):