Mirror of https://github.com/comfyanonymous/ComfyUI.git
Use own clip vision model implementation.
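This commit drops the transformers CLIPVisionModelWithProjection dependency from ComfyUI's clip vision loader and builds the model with ComfyUI's own comfy.clip_model.CLIPVisionModelProjection instead. The hunks below are from comfy/clip_vision.py: first the imports change and a small Output helper is added, then ClipVisionModel.__init__ switches to a plain JSON config dict, and finally encode_image is rewritten around the new model's intermediate_output argument.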
@@ -1,13 +1,20 @@
-from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils
 from .utils import load_torch_file, transformers_convert, common_upscale
 import os
 import torch
-import contextlib
+import json
 
 import comfy.ops
 import comfy.model_patcher
 import comfy.model_management
+import comfy.utils
+import comfy.clip_model
+
+class Output:
+    def __getitem__(self, key):
+        return getattr(self, key)
+    def __setitem__(self, key, item):
+        setattr(self, key, item)
 
 def clip_preprocess(image, size=224):
     mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
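The new Output class is a tiny stand-in for the mapping-style output objects that transformers returns: item access simply delegates to attribute access. A minimal self-contained sketch of the behavior (the instance and values are illustrative, not from the commit):

class Output:
    def __getitem__(self, key):
        return getattr(self, key)
    def __setitem__(self, key, item):
        setattr(self, key, item)

o = Output()
o["image_embeds"] = [1.0, 2.0]               # stored via setattr
assert o.image_embeds is o["image_embeds"]   # both views hit the same attribute

Note that Output only supports get/set by key, not iteration, which is all the rewritten encode_image below needs; the old for-loop over the transformers output goes away.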
@@ -22,17 +29,16 @@ def clip_preprocess(image, size=224):
 
 class ClipVisionModel():
     def __init__(self, json_config):
-        config = CLIPVisionConfig.from_json_file(json_config)
+        with open(json_config) as f:
+            config = json.load(f)
 
         self.load_device = comfy.model_management.text_encoder_device()
         offload_device = comfy.model_management.text_encoder_offload_device()
         self.dtype = torch.float32
         if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
             self.dtype = torch.float16
 
-        with comfy.ops.use_comfy_ops(offload_device, self.dtype):
-            with modeling_utils.no_init_weights():
-                self.model = CLIPVisionModelWithProjection(config)
-        self.model.to(self.dtype)
+        self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops)
 
         self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
     def load_sd(self, sd):
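__init__ now parses the vision config JSON itself instead of routing it through CLIPVisionConfig, and hands the resulting plain dict (plus dtype, offload device, and the comfy.ops namespace) straight to the new model class. A rough sketch of the config side in isolation, using a hypothetical miniature config file (the field names are illustrative of what a CLIP vision config JSON typically contains):

import json
import tempfile

# Hypothetical miniature config in the shape of a CLIP vision config JSON.
cfg = '{"hidden_size": 1024, "num_hidden_layers": 24, "image_size": 224, "patch_size": 14}'
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    f.write(cfg)
    json_config = f.name

# Same pattern as the new __init__: a plain dict, no transformers config object.
with open(json_config) as f:
    config = json.load(f)
print(config["hidden_size"])  # -> 1024

Because the new model class receives the dtype at construction time, the old build-in-fp32-then-cast step (self.model.to(self.dtype)) and the modeling_utils.no_init_weights() context are no longer needed.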
@@ -48,17 +54,12 @@ class ClipVisionModel():
-            precision_scope = lambda a, b: contextlib.nullcontext(a)
-
-        with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32):
-            outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)
-
-        for k in outputs:
-            t = outputs[k]
-            if t is not None:
-                if k == 'hidden_states':
-                    outputs["penultimate_hidden_states"] = t[-2].to(comfy.model_management.intermediate_device())
-                    outputs["hidden_states"] = None
-                else:
-                    outputs[k] = t.to(comfy.model_management.intermediate_device())
+        out = self.model(pixel_values=pixel_values, intermediate_output=-2)
+
+        outputs = Output()
+        outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
+        outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
+        outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
         return outputs
 
 def convert_to_transformers(sd, prefix):
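encode_image no longer requests all hidden states from transformers and filters them in a loop; the new model takes intermediate_output=-2 and hands back the penultimate hidden state directly. Judging by the indexing in the hunk, the model returns a tuple of (last_hidden_state, intermediate hidden state, image_embeds). A hedged sketch of that calling convention with a dummy stand-in model (the stand-in, its shapes, and values are illustrative only, not ComfyUI code):

import torch

class Output:
    # Same helper as added in the first hunk.
    def __getitem__(self, key):
        return getattr(self, key)
    def __setitem__(self, key, item):
        setattr(self, key, item)

def dummy_vision_model(pixel_values, intermediate_output=-2):
    # Stand-in with the return contract the diff implies for
    # comfy.clip_model.CLIPVisionModelProjection:
    # (last_hidden_state, intermediate hidden state, image_embeds).
    b = pixel_values.shape[0]
    tokens = torch.zeros(b, 257, 1024)
    return tokens, tokens.clone(), torch.zeros(b, 768)

out = dummy_vision_model(torch.zeros(1, 3, 224, 224), intermediate_output=-2)
outputs = Output()
outputs["last_hidden_state"] = out[0]
outputs["penultimate_hidden_states"] = out[1]
outputs["image_embeds"] = out[2]
print(outputs.image_embeds.shape)  # attribute access works too: torch.Size([1, 768])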