Use own clip vision model implementation.

comfyanonymous
2023-12-09 11:56:31 -05:00
parent 97015b6b38
commit 174eba8e95
2 changed files with 81 additions and 22 deletions

comfy/clip_vision.py

@@ -1,13 +1,20 @@
-from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils
 from .utils import load_torch_file, transformers_convert, common_upscale
 import os
 import torch
-import contextlib
+import json

 import comfy.ops
 import comfy.model_patcher
 import comfy.model_management
 import comfy.utils
+import comfy.clip_model
+
+class Output:
+    def __getitem__(self, key):
+        return getattr(self, key)
+    def __setitem__(self, key, item):
+        setattr(self, key, item)

 def clip_preprocess(image, size=224):
     mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
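Note on the hunk above: the new Output class is a tiny stand-in for the transformers output object it replaces. Anything stored with [] becomes a plain attribute, so results stay readable both dict-style and attribute-style. A minimal self-contained sketch (the stored value is a placeholder, not from the commit):

class Output:
    def __getitem__(self, key):
        return getattr(self, key)
    def __setitem__(self, key, item):
        setattr(self, key, item)

o = Output()
o["image_embeds"] = 42        # stored via setattr
print(o.image_embeds)         # 42, attribute-style access
print(o["image_embeds"])      # 42, dict-style access via getattr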
@@ -22,17 +29,16 @@ def clip_preprocess(image, size=224):
 class ClipVisionModel():
     def __init__(self, json_config):
-        config = CLIPVisionConfig.from_json_file(json_config)
+        with open(json_config) as f:
+            config = json.load(f)
+
         self.load_device = comfy.model_management.text_encoder_device()
         offload_device = comfy.model_management.text_encoder_offload_device()
         self.dtype = torch.float32
         if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
             self.dtype = torch.float16

-        with comfy.ops.use_comfy_ops(offload_device, self.dtype):
-            with modeling_utils.no_init_weights():
-                self.model = CLIPVisionModelWithProjection(config)
-        self.model.to(self.dtype)
-
+        self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops)
         self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

     def load_sd(self, sd):
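Note on the constructor change: instead of initializing fp32 weights under no_init_weights and converting with .to(dtype) afterwards, the config is read as a plain dict and handed to ComfyUI's own CLIPVisionModelProjection along with the dtype, offload device, and an ops module. A simplified sketch of the underlying idea (Ops and make_linear are illustrative stand-ins, not ComfyUI API): creating layers through an ops factory whose reset_parameters is a no-op builds weights directly in the target dtype and device, and skips random init entirely since the real weights arrive later via load_sd.

import torch
import torch.nn as nn

class Ops:  # illustrative stand-in for comfy.ops
    class Linear(nn.Linear):
        def reset_parameters(self):
            return None  # skip random init; the state dict fills the weights

def make_linear(in_features, out_features, dtype, device, operations=Ops):
    # the layer is created directly in the target dtype/device,
    # so no fp32 init followed by a .to(dtype) conversion is needed
    return operations.Linear(in_features, out_features, dtype=dtype, device=device)

layer = make_linear(1024, 1024, dtype=torch.float16, device="cpu")
print(layer.weight.dtype)  # torch.float16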
@@ -48,17 +54,12 @@ class ClipVisionModel():
-            precision_scope = lambda a, b: contextlib.nullcontext(a)
-
-        with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32):
-            outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)
-
-        for k in outputs:
-            t = outputs[k]
-            if t is not None:
-                if k == 'hidden_states':
-                    outputs["penultimate_hidden_states"] = t[-2].to(comfy.model_management.intermediate_device())
-                    outputs["hidden_states"] = None
-                else:
-                    outputs[k] = t.to(comfy.model_management.intermediate_device())
-
+        out = self.model(pixel_values=pixel_values, intermediate_output=-2)
+
+        outputs = Output()
+        outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
+        outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
+        outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
         return outputs

 def convert_to_transformers(sd, prefix):
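Note on the encode_image change: the autocast/output_hidden_states plumbing is gone. The new model is called with intermediate_output=-2 and returns positional results that map onto the old field names: out[0] is the last hidden state, out[1] the requested intermediate layer (here -2, the penultimate), out[2] the projected image embeds. A toy sketch of that mapping (fake_model and all tensor shapes are made up for illustration; the real model is comfy.clip_model.CLIPVisionModelProjection):

import torch

def fake_model(pixel_values, intermediate_output=-2):
    # stand-in returning (last_hidden_state, intermediate, image_embeds)
    tokens = torch.zeros(1, 257, 1280)   # illustrative ViT-H-like shapes
    return (tokens, tokens.clone(), torch.zeros(1, 1024))

out = fake_model(torch.zeros(1, 3, 224, 224))
outputs = {
    "last_hidden_state": out[0],
    "penultimate_hidden_states": out[1],  # hidden states after layer -2
    "image_embeds": out[2],
}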