USO style reference. (#9677)
Load the projector.safetensors file with the ModelPatchLoader node, use siglip_vision_patch14_384.safetensors as the "clip vision" model, and apply both with the USOStyleReference node.
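
For reference, a minimal sketch of how those three pieces fit together when driven from Python rather than the node graph. This is an assumption-laden illustration: it presumes the node definitions live in comfy_extras.nodes_model_patch, that ModelPatchLoader's function is named load_model_patch, and it uses placeholder file paths and a random image tensor in place of real inputs.

import torch
import comfy.sd
import comfy.clip_vision
from comfy_extras.nodes_model_patch import ModelPatchLoader, USOStyleReference  # module path assumed

# Placeholder paths; in the node graph these come from the usual loader nodes.
model = comfy.sd.load_diffusion_model("models/diffusion_models/flux1-dev.safetensors")
clip_vision = comfy.clip_vision.load("models/clip_vision/siglip_vision_patch14_384.safetensors")

image = torch.rand(1, 512, 512, 3)  # stand-in style image, [B, H, W, C] in 0..1
clip_vision_output = clip_vision.encode_image(image)  # the SigLIP path also fills all_hidden_states

model_patch, = ModelPatchLoader().load_model_patch("projector.safetensors")  # method name assumed
styled_model, = USOStyleReference().apply_patch(model, model_patch, clip_vision_output)
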
@@ -61,8 +61,12 @@ class CLIPEncoder(torch.nn.Module):
     def forward(self, x, mask=None, intermediate_output=None):
         optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
 
+        all_intermediate = None
         if intermediate_output is not None:
-            if intermediate_output < 0:
+            if intermediate_output == "all":
+                all_intermediate = []
+                intermediate_output = None
+            elif intermediate_output < 0:
                 intermediate_output = len(self.layers) + intermediate_output
 
         intermediate = None
@@ -70,6 +74,12 @@ class CLIPEncoder(torch.nn.Module):
             x = l(x, mask, optimized_attention)
             if i == intermediate_output:
                 intermediate = x.clone()
+            if all_intermediate is not None:
+                all_intermediate.append(x.unsqueeze(1).clone())
+
+        if all_intermediate is not None:
+            intermediate = torch.cat(all_intermediate, dim=1)
+
         return x, intermediate
 
 class CLIPEmbeddings(torch.nn.Module):
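
To make the new control flow concrete, here is a standalone toy version of the intermediate_output="all" path, using plain linear layers in place of CLIP attention blocks; shapes and layer count are illustrative only.

import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
x = torch.randn(2, 3, 8)  # [batch, tokens, dim]

def run(x, intermediate_output=None):
    all_intermediate = None
    if intermediate_output is not None:
        if intermediate_output == "all":
            all_intermediate = []
            intermediate_output = None
        elif intermediate_output < 0:
            intermediate_output = len(layers) + intermediate_output

    intermediate = None
    for i, l in enumerate(layers):
        x = l(x)
        if i == intermediate_output:
            intermediate = x.clone()
        if all_intermediate is not None:
            all_intermediate.append(x.unsqueeze(1).clone())

    if all_intermediate is not None:
        intermediate = torch.cat(all_intermediate, dim=1)
    return x, intermediate

_, per_layer = run(x, "all")
print(per_layer.shape)  # torch.Size([2, 4, 3, 8]) -> one entry per layer along dim=1
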
@@ -50,7 +50,13 @@ class ClipVisionModel():
         self.image_size = config.get("image_size", 224)
         self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
         self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
-        model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
+        model_type = config.get("model_type", "clip_vision_model")
+        model_class = IMAGE_ENCODERS.get(model_type)
+        if model_type == "siglip_vision_model":
+            self.return_all_hidden_states = True
+        else:
+            self.return_all_hidden_states = False
+
         self.load_device = comfy.model_management.text_encoder_device()
         offload_device = comfy.model_management.text_encoder_offload_device()
         self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
@@ -68,12 +74,18 @@ class ClipVisionModel():
     def encode_image(self, image, crop=True):
         comfy.model_management.load_model_gpu(self.patcher)
         pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
-        out = self.model(pixel_values=pixel_values, intermediate_output=-2)
+        out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
 
         outputs = Output()
         outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
         outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
-        outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
+        if self.return_all_hidden_states:
+            all_hs = out[1].to(comfy.model_management.intermediate_device())
+            outputs["penultimate_hidden_states"] = all_hs[:, -2]
+            outputs["all_hidden_states"] = all_hs
+        else:
+            outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
+
         outputs["mm_projected"] = out[3]
         return outputs
 
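
A quick illustration of how the new output is indexed downstream. The tensor here is a dummy standing in for a SigLIP encode_image() result; the layer count and dimensions are assumptions for illustration.

import torch

all_hidden_states = torch.randn(1, 27, 729, 1152)  # [batch, layer, token, dim]
penultimate = all_hidden_states[:, -2]              # layer -2, also stored as "penultimate_hidden_states"
mid = all_hidden_states[:, -11]                     # layer -11, used by the USO projector
low = all_hidden_states[:, -20]                     # layer -20, used by the USO projector
print(penultimate.shape, mid.shape, low.shape)      # each torch.Size([1, 729, 1152])
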
@@ -106,6 +106,7 @@ class Flux(nn.Module):
         if y is None:
             y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
 
+        patches = transformer_options.get("patches", {})
         patches_replace = transformer_options.get("patches_replace", {})
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -117,9 +118,17 @@ class Flux(nn.Module):
         if guidance is not None:
             vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
 
-        vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
+        vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
         txt = self.txt_in(txt)
 
+        if "post_input" in patches:
+            for p in patches["post_input"]:
+                out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
+                img = out["img"]
+                txt = out["txt"]
+                img_ids = out["img_ids"]
+                txt_ids = out["txt_ids"]
+
         if img_ids is not None:
             ids = torch.cat((txt_ids, img_ids), dim=1)
             pe = self.pe_embedder(ids)
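
The "post_input" entry is a list of callables that receive the current img/txt tensors plus their position ids and return (possibly modified) versions of all four. A self-contained toy patch, with illustrative shapes only, that mirrors the dict-in / dict-out contract used later in this commit by UsoStyleProjectorPatch:

import torch

def prepend_two_txt_tokens(kwargs):
    txt, txt_ids = kwargs["txt"], kwargs["txt_ids"]
    extra = torch.zeros(txt.shape[0], 2, txt.shape[2], dtype=txt.dtype)
    extra_ids = torch.zeros(txt.shape[0], 2, 3, dtype=txt_ids.dtype)
    kwargs["txt"] = torch.cat([extra, txt], dim=1)
    kwargs["txt_ids"] = torch.cat([extra_ids, txt_ids], dim=1)
    return kwargs

patches = {"post_input": [prepend_two_txt_tokens]}
state = {"img": torch.randn(1, 16, 3072), "txt": torch.randn(1, 8, 3072),
         "img_ids": torch.zeros(1, 16, 3), "txt_ids": torch.zeros(1, 8, 3)}
for p in patches["post_input"]:
    state = p(state)
print(state["txt"].shape, state["txt_ids"].shape)  # [1, 10, 3072] and [1, 10, 3]
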
@@ -433,6 +433,9 @@ class ModelPatcher:
     def set_model_double_block_patch(self, patch):
         self.set_model_patch(patch, "double_block")
 
+    def set_model_post_input_patch(self, patch):
+        self.set_model_patch(patch, "post_input")
+
     def add_object_patch(self, name, obj):
         self.object_patches[name] = obj
 
@@ -1,4 +1,5 @@
 import torch
+from torch import nn
 import folder_paths
 import comfy.utils
 import comfy.ops
@@ -58,6 +59,136 @@ class QwenImageBlockWiseControlNet(torch.nn.Module):
         return self.controlnet_blocks[block_id](img, controlnet_conditioning)
 
 
+class SigLIPMultiFeatProjModel(torch.nn.Module):
+    """
+    SigLIP Multi-Feature Projection Model for processing style features from different layers
+    and projecting them into a unified hidden space.
+
+    Args:
+        siglip_token_nums (int): Number of SigLIP tokens, default 257
+        style_token_nums (int): Number of style tokens, default 256
+        siglip_token_dims (int): Dimension of SigLIP tokens, default 1536
+        hidden_size (int): Hidden layer size, default 3072
+        context_layer_norm (bool): Whether to use context layer normalization, default False
+    """
+
+    def __init__(
+        self,
+        siglip_token_nums: int = 729,
+        style_token_nums: int = 64,
+        siglip_token_dims: int = 1152,
+        hidden_size: int = 3072,
+        context_layer_norm: bool = True,
+        device=None, dtype=None, operations=None
+    ):
+        super().__init__()
+
+        # High-level feature processing (layer -2)
+        self.high_embedding_linear = nn.Sequential(
+            operations.Linear(siglip_token_nums, style_token_nums),
+            nn.SiLU()
+        )
+        self.high_layer_norm = (
+            operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
+        )
+        self.high_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)
+
+        # Mid-level feature processing (layer -11)
+        self.mid_embedding_linear = nn.Sequential(
+            operations.Linear(siglip_token_nums, style_token_nums),
+            nn.SiLU()
+        )
+        self.mid_layer_norm = (
+            operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
+        )
+        self.mid_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)
+
+        # Low-level feature processing (layer -20)
+        self.low_embedding_linear = nn.Sequential(
+            operations.Linear(siglip_token_nums, style_token_nums),
+            nn.SiLU()
+        )
+        self.low_layer_norm = (
+            operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
+        )
+        self.low_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)
+
+    def forward(self, siglip_outputs):
+        """
+        Forward pass function
+
+        Args:
+            siglip_outputs: Output from SigLIP model, containing hidden_states
+
+        Returns:
+            torch.Tensor: Concatenated multi-layer features with shape [bs, 3*style_token_nums, hidden_size]
+        """
+        dtype = next(self.high_embedding_linear.parameters()).dtype
+
+        # Process high-level features (layer -2)
+        high_embedding = self._process_layer_features(
+            siglip_outputs[2],
+            self.high_embedding_linear,
+            self.high_layer_norm,
+            self.high_projection,
+            dtype
+        )
+
+        # Process mid-level features (layer -11)
+        mid_embedding = self._process_layer_features(
+            siglip_outputs[1],
+            self.mid_embedding_linear,
+            self.mid_layer_norm,
+            self.mid_projection,
+            dtype
+        )
+
+        # Process low-level features (layer -20)
+        low_embedding = self._process_layer_features(
+            siglip_outputs[0],
+            self.low_embedding_linear,
+            self.low_layer_norm,
+            self.low_projection,
+            dtype
+        )
+
+        # Concatenate features from all layers
+        return torch.cat((high_embedding, mid_embedding, low_embedding), dim=1)
+
+    def _process_layer_features(
+        self,
+        hidden_states: torch.Tensor,
+        embedding_linear: nn.Module,
+        layer_norm: nn.Module,
+        projection: nn.Module,
+        dtype: torch.dtype
+    ) -> torch.Tensor:
+        """
+        Helper function to process features from a single layer
+
+        Args:
+            hidden_states: Input hidden states [bs, seq_len, dim]
+            embedding_linear: Embedding linear layer
+            layer_norm: Layer normalization
+            projection: Projection layer
+            dtype: Target data type
+
+        Returns:
+            torch.Tensor: Processed features [bs, style_token_nums, hidden_size]
+        """
+        # Transform dimensions: [bs, seq_len, dim] -> [bs, dim, seq_len] -> [bs, dim, style_token_nums] -> [bs, style_token_nums, dim]
+        embedding = embedding_linear(
+            hidden_states.to(dtype).transpose(1, 2)
+        ).transpose(1, 2)
+
+        # Apply layer normalization
+        embedding = layer_norm(embedding)
+
+        # Project to target hidden space
+        embedding = projection(embedding)
+
+        return embedding
+
+
 class ModelPatchLoader:
     @classmethod
     def INPUT_TYPES(s):
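
A shape check for the projector added above, assuming the class is importable from comfy_extras.nodes_model_patch; torch.nn is used here as a stand-in for comfy's operations argument (it provides the Linear/LayerNorm the class expects), whereas the real loader passes comfy.ops.manual_cast.

import torch
import torch.nn as nn
from comfy_extras.nodes_model_patch import SigLIPMultiFeatProjModel  # module path assumed

proj = SigLIPMultiFeatProjModel(operations=nn)
feats = [torch.randn(1, 729, 1152) for _ in range(3)]  # low (-20), mid (-11), high (penultimate) hidden states
print(proj(feats).shape)  # torch.Size([1, 192, 3072]) -> 3 * 64 style tokens at the Flux hidden size
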
@@ -73,9 +204,14 @@ class ModelPatchLoader:
         model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name)
         sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True)
         dtype = comfy.utils.weight_dtype(sd)
-        # TODO: this node will work with more types of model patches
-        additional_in_dim = sd["img_in.weight"].shape[1] - 64
-        model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
+        if 'controlnet_blocks.0.y_rms.weight' in sd:
+            additional_in_dim = sd["img_in.weight"].shape[1] - 64
+            model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
+        elif 'feature_embedder.mid_layer_norm.bias' in sd:
+            sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
+            model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
+
         model.load_state_dict(sd)
         model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
         return (model,)
@@ -157,7 +293,51 @@ class QwenImageDiffsynthControlnet:
         return (model_patched,)
 
 
+class UsoStyleProjectorPatch:
+    def __init__(self, model_patch, encoded_image):
+        self.model_patch = model_patch
+        self.encoded_image = encoded_image
+
+    def __call__(self, kwargs):
+        txt_ids = kwargs.get("txt_ids")
+        txt = kwargs.get("txt")
+        siglip_embedding = self.model_patch.model(self.encoded_image.to(txt.dtype)).to(txt.dtype)
+        txt = torch.cat([siglip_embedding, txt], dim=1)
+        kwargs['txt'] = txt
+        kwargs['txt_ids'] = torch.cat([torch.zeros(siglip_embedding.shape[0], siglip_embedding.shape[1], 3, dtype=txt_ids.dtype, device=txt_ids.device), txt_ids], dim=1)
+        return kwargs
+
+    def to(self, device_or_dtype):
+        if isinstance(device_or_dtype, torch.device):
+            self.encoded_image = self.encoded_image.to(device_or_dtype)
+        return self
+
+    def models(self):
+        return [self.model_patch]
+
+
+class USOStyleReference:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"model": ("MODEL",),
+                             "model_patch": ("MODEL_PATCH",),
+                             "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
+                             }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "apply_patch"
+    EXPERIMENTAL = True
+
+    CATEGORY = "advanced/model_patches/flux"
+
+    def apply_patch(self, model, model_patch, clip_vision_output):
+        encoded_image = torch.stack((clip_vision_output.all_hidden_states[:, -20], clip_vision_output.all_hidden_states[:, -11], clip_vision_output.penultimate_hidden_states))
+        model_patched = model.clone()
+        model_patched.set_model_post_input_patch(UsoStyleProjectorPatch(model_patch, encoded_image))
+        return (model_patched,)
+
+
 NODE_CLASS_MAPPINGS = {
     "ModelPatchLoader": ModelPatchLoader,
     "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
+    "USOStyleReference": USOStyleReference,
 }
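
Finally, a sketch of how apply_patch's torch.stack lines up with the projector's indexing above: stack index 0 is layer -20 (low), index 1 is layer -11 (mid), and index 2 is the penultimate layer (high). The tensor below is a dummy all_hidden_states with assumed dimensions.

import torch

all_hs = torch.randn(1, 27, 729, 1152)   # dummy all_hidden_states: [batch, layer, token, dim]
penultimate = all_hs[:, -2]
encoded_image = torch.stack((all_hs[:, -20], all_hs[:, -11], penultimate))
print(encoded_image.shape)  # torch.Size([3, 1, 729, 1152]), consumed as siglip_outputs[0..2]
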