From 3412d53b1d69e4dfedf7e86c3092cea085094053 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 2 Sep 2025 12:36:22 -0700 Subject: [PATCH] USO style reference. (#9677) Load the projector.safetensors file with the ModelPatchLoader node and use the siglip_vision_patch14_384.safetensors "clip vision" model and the USOStyleReferenceNode. --- comfy/clip_model.py | 12 +- comfy/clip_vision.py | 18 ++- comfy/ldm/flux/model.py | 11 +- comfy/model_patcher.py | 3 + comfy_extras/nodes_model_patch.py | 186 +++++++++++++++++++++++++++++- 5 files changed, 222 insertions(+), 8 deletions(-) diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 7e47d8a55..7c0cadab5 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -61,8 +61,12 @@ class CLIPEncoder(torch.nn.Module): def forward(self, x, mask=None, intermediate_output=None): optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) + all_intermediate = None if intermediate_output is not None: - if intermediate_output < 0: + if intermediate_output == "all": + all_intermediate = [] + intermediate_output = None + elif intermediate_output < 0: intermediate_output = len(self.layers) + intermediate_output intermediate = None @@ -70,6 +74,12 @@ class CLIPEncoder(torch.nn.Module): x = l(x, mask, optimized_attention) if i == intermediate_output: intermediate = x.clone() + if all_intermediate is not None: + all_intermediate.append(x.unsqueeze(1).clone()) + + if all_intermediate is not None: + intermediate = torch.cat(all_intermediate, dim=1) + return x, intermediate class CLIPEmbeddings(torch.nn.Module): diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 00aab9164..2fa410cb7 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -50,7 +50,13 @@ class ClipVisionModel(): self.image_size = config.get("image_size", 224) self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073]) self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711]) - model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model")) + model_type = config.get("model_type", "clip_vision_model") + model_class = IMAGE_ENCODERS.get(model_type) + if model_type == "siglip_vision_model": + self.return_all_hidden_states = True + else: + self.return_all_hidden_states = False + self.load_device = comfy.model_management.text_encoder_device() offload_device = comfy.model_management.text_encoder_offload_device() self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) @@ -68,12 +74,18 @@ class ClipVisionModel(): def encode_image(self, image, crop=True): comfy.model_management.load_model_gpu(self.patcher) pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() - out = self.model(pixel_values=pixel_values, intermediate_output=-2) + out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2) outputs = Output() outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device()) outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device()) - outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device()) + if self.return_all_hidden_states: + all_hs = out[1].to(comfy.model_management.intermediate_device()) + outputs["penultimate_hidden_states"] = all_hs[:, -2] + outputs["all_hidden_states"] = all_hs + else: + outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device()) + outputs["mm_projected"] = out[3] return outputs diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 1e62f4626..d4be6bb61 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -106,6 +106,7 @@ class Flux(nn.Module): if y is None: y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) + patches = transformer_options.get("patches", {}) patches_replace = transformer_options.get("patches_replace", {}) if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -117,9 +118,17 @@ class Flux(nn.Module): if guidance is not None: vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) - vec = vec + self.vector_in(y[:,:self.params.vec_in_dim]) + vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) txt = self.txt_in(txt) + if "post_input" in patches: + for p in patches["post_input"]: + out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids}) + img = out["img"] + txt = out["txt"] + img_ids = out["img_ids"] + txt_ids = out["txt_ids"] + if img_ids is not None: ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index a944cb421..1fd03d9d1 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -433,6 +433,9 @@ class ModelPatcher: def set_model_double_block_patch(self, patch): self.set_model_patch(patch, "double_block") + def set_model_post_input_patch(self, patch): + self.set_model_patch(patch, "post_input") + def add_object_patch(self, name, obj): self.object_patches[name] = obj diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 65e766b52..783c59b6b 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -1,4 +1,5 @@ import torch +from torch import nn import folder_paths import comfy.utils import comfy.ops @@ -58,6 +59,136 @@ class QwenImageBlockWiseControlNet(torch.nn.Module): return self.controlnet_blocks[block_id](img, controlnet_conditioning) +class SigLIPMultiFeatProjModel(torch.nn.Module): + """ + SigLIP Multi-Feature Projection Model for processing style features from different layers + and projecting them into a unified hidden space. + + Args: + siglip_token_nums (int): Number of SigLIP tokens, default 257 + style_token_nums (int): Number of style tokens, default 256 + siglip_token_dims (int): Dimension of SigLIP tokens, default 1536 + hidden_size (int): Hidden layer size, default 3072 + context_layer_norm (bool): Whether to use context layer normalization, default False + """ + + def __init__( + self, + siglip_token_nums: int = 729, + style_token_nums: int = 64, + siglip_token_dims: int = 1152, + hidden_size: int = 3072, + context_layer_norm: bool = True, + device=None, dtype=None, operations=None + ): + super().__init__() + + # High-level feature processing (layer -2) + self.high_embedding_linear = nn.Sequential( + operations.Linear(siglip_token_nums, style_token_nums), + nn.SiLU() + ) + self.high_layer_norm = ( + operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity() + ) + self.high_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True) + + # Mid-level feature processing (layer -11) + self.mid_embedding_linear = nn.Sequential( + operations.Linear(siglip_token_nums, style_token_nums), + nn.SiLU() + ) + self.mid_layer_norm = ( + operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity() + ) + self.mid_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True) + + # Low-level feature processing (layer -20) + self.low_embedding_linear = nn.Sequential( + operations.Linear(siglip_token_nums, style_token_nums), + nn.SiLU() + ) + self.low_layer_norm = ( + operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity() + ) + self.low_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True) + + def forward(self, siglip_outputs): + """ + Forward pass function + + Args: + siglip_outputs: Output from SigLIP model, containing hidden_states + + Returns: + torch.Tensor: Concatenated multi-layer features with shape [bs, 3*style_token_nums, hidden_size] + """ + dtype = next(self.high_embedding_linear.parameters()).dtype + + # Process high-level features (layer -2) + high_embedding = self._process_layer_features( + siglip_outputs[2], + self.high_embedding_linear, + self.high_layer_norm, + self.high_projection, + dtype + ) + + # Process mid-level features (layer -11) + mid_embedding = self._process_layer_features( + siglip_outputs[1], + self.mid_embedding_linear, + self.mid_layer_norm, + self.mid_projection, + dtype + ) + + # Process low-level features (layer -20) + low_embedding = self._process_layer_features( + siglip_outputs[0], + self.low_embedding_linear, + self.low_layer_norm, + self.low_projection, + dtype + ) + + # Concatenate features from all layersmodel_patch + return torch.cat((high_embedding, mid_embedding, low_embedding), dim=1) + + def _process_layer_features( + self, + hidden_states: torch.Tensor, + embedding_linear: nn.Module, + layer_norm: nn.Module, + projection: nn.Module, + dtype: torch.dtype + ) -> torch.Tensor: + """ + Helper function to process features from a single layer + + Args: + hidden_states: Input hidden states [bs, seq_len, dim] + embedding_linear: Embedding linear layer + layer_norm: Layer normalization + projection: Projection layer + dtype: Target data type + + Returns: + torch.Tensor: Processed features [bs, style_token_nums, hidden_size] + """ + # Transform dimensions: [bs, seq_len, dim] -> [bs, dim, seq_len] -> [bs, dim, style_token_nums] -> [bs, style_token_nums, dim] + embedding = embedding_linear( + hidden_states.to(dtype).transpose(1, 2) + ).transpose(1, 2) + + # Apply layer normalization + embedding = layer_norm(embedding) + + # Project to target hidden space + embedding = projection(embedding) + + return embedding + class ModelPatchLoader: @classmethod def INPUT_TYPES(s): @@ -73,9 +204,14 @@ class ModelPatchLoader: model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name) sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True) dtype = comfy.utils.weight_dtype(sd) - # TODO: this node will work with more types of model patches - additional_in_dim = sd["img_in.weight"].shape[1] - 64 - model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + + if 'controlnet_blocks.0.y_rms.weight' in sd: + additional_in_dim = sd["img_in.weight"].shape[1] - 64 + model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + elif 'feature_embedder.mid_layer_norm.bias' in sd: + sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True) + model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + model.load_state_dict(sd) model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) return (model,) @@ -157,7 +293,51 @@ class QwenImageDiffsynthControlnet: return (model_patched,) +class UsoStyleProjectorPatch: + def __init__(self, model_patch, encoded_image): + self.model_patch = model_patch + self.encoded_image = encoded_image + + def __call__(self, kwargs): + txt_ids = kwargs.get("txt_ids") + txt = kwargs.get("txt") + siglip_embedding = self.model_patch.model(self.encoded_image.to(txt.dtype)).to(txt.dtype) + txt = torch.cat([siglip_embedding, txt], dim=1) + kwargs['txt'] = txt + kwargs['txt_ids'] = torch.cat([torch.zeros(siglip_embedding.shape[0], siglip_embedding.shape[1], 3, dtype=txt_ids.dtype, device=txt_ids.device), txt_ids], dim=1) + return kwargs + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + self.encoded_image = self.encoded_image.to(device_or_dtype) + return self + + def models(self): + return [self.model_patch] + + +class USOStyleReference: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL",), + "model_patch": ("MODEL_PATCH",), + "clip_vision_output": ("CLIP_VISION_OUTPUT", ), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "apply_patch" + EXPERIMENTAL = True + + CATEGORY = "advanced/model_patches/flux" + + def apply_patch(self, model, model_patch, clip_vision_output): + encoded_image = torch.stack((clip_vision_output.all_hidden_states[:, -20], clip_vision_output.all_hidden_states[:, -11], clip_vision_output.penultimate_hidden_states)) + model_patched = model.clone() + model_patched.set_model_post_input_patch(UsoStyleProjectorPatch(model_patch, encoded_image)) + return (model_patched,) + + NODE_CLASS_MAPPINGS = { "ModelPatchLoader": ModelPatchLoader, "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, + "USOStyleReference": USOStyleReference, }