mirror of
				https://github.com/comfyanonymous/ComfyUI.git
				synced 2025-10-25 07:54:30 +00:00 
			
		
		
		
	USO style reference. (#9677)
Load the projector.safetensors file with the ModelPatchLoader node and use the siglip_vision_patch14_384.safetensors "clip vision" model and the USOStyleReferenceNode.
This commit is contained in:
		| @@ -61,8 +61,12 @@ class CLIPEncoder(torch.nn.Module): | ||||
|     def forward(self, x, mask=None, intermediate_output=None): | ||||
|         optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) | ||||
|  | ||||
|         all_intermediate = None | ||||
|         if intermediate_output is not None: | ||||
|             if intermediate_output < 0: | ||||
|             if intermediate_output == "all": | ||||
|                 all_intermediate = [] | ||||
|                 intermediate_output = None | ||||
|             elif intermediate_output < 0: | ||||
|                 intermediate_output = len(self.layers) + intermediate_output | ||||
|  | ||||
|         intermediate = None | ||||
| @@ -70,6 +74,12 @@ class CLIPEncoder(torch.nn.Module): | ||||
|             x = l(x, mask, optimized_attention) | ||||
|             if i == intermediate_output: | ||||
|                 intermediate = x.clone() | ||||
|             if all_intermediate is not None: | ||||
|                 all_intermediate.append(x.unsqueeze(1).clone()) | ||||
|  | ||||
|         if all_intermediate is not None: | ||||
|             intermediate = torch.cat(all_intermediate, dim=1) | ||||
|  | ||||
|         return x, intermediate | ||||
|  | ||||
| class CLIPEmbeddings(torch.nn.Module): | ||||
|   | ||||
| @@ -50,7 +50,13 @@ class ClipVisionModel(): | ||||
|         self.image_size = config.get("image_size", 224) | ||||
|         self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073]) | ||||
|         self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711]) | ||||
|         model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model")) | ||||
|         model_type = config.get("model_type", "clip_vision_model") | ||||
|         model_class = IMAGE_ENCODERS.get(model_type) | ||||
|         if model_type == "siglip_vision_model": | ||||
|             self.return_all_hidden_states = True | ||||
|         else: | ||||
|             self.return_all_hidden_states = False | ||||
|  | ||||
|         self.load_device = comfy.model_management.text_encoder_device() | ||||
|         offload_device = comfy.model_management.text_encoder_offload_device() | ||||
|         self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) | ||||
| @@ -68,12 +74,18 @@ class ClipVisionModel(): | ||||
|     def encode_image(self, image, crop=True): | ||||
|         comfy.model_management.load_model_gpu(self.patcher) | ||||
|         pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() | ||||
|         out = self.model(pixel_values=pixel_values, intermediate_output=-2) | ||||
|         out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2) | ||||
|  | ||||
|         outputs = Output() | ||||
|         outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device()) | ||||
|         outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device()) | ||||
|         outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device()) | ||||
|         if self.return_all_hidden_states: | ||||
|             all_hs = out[1].to(comfy.model_management.intermediate_device()) | ||||
|             outputs["penultimate_hidden_states"] = all_hs[:, -2] | ||||
|             outputs["all_hidden_states"] = all_hs | ||||
|         else: | ||||
|             outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device()) | ||||
|  | ||||
|         outputs["mm_projected"] = out[3] | ||||
|         return outputs | ||||
|  | ||||
|   | ||||
| @@ -106,6 +106,7 @@ class Flux(nn.Module): | ||||
|         if y is None: | ||||
|             y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) | ||||
|  | ||||
|         patches = transformer_options.get("patches", {}) | ||||
|         patches_replace = transformer_options.get("patches_replace", {}) | ||||
|         if img.ndim != 3 or txt.ndim != 3: | ||||
|             raise ValueError("Input img and txt tensors must have 3 dimensions.") | ||||
| @@ -117,9 +118,17 @@ class Flux(nn.Module): | ||||
|             if guidance is not None: | ||||
|                 vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) | ||||
|  | ||||
|         vec = vec + self.vector_in(y[:,:self.params.vec_in_dim]) | ||||
|         vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) | ||||
|         txt = self.txt_in(txt) | ||||
|  | ||||
|         if "post_input" in patches: | ||||
|             for p in patches["post_input"]: | ||||
|                 out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids}) | ||||
|                 img = out["img"] | ||||
|                 txt = out["txt"] | ||||
|                 img_ids = out["img_ids"] | ||||
|                 txt_ids = out["txt_ids"] | ||||
|  | ||||
|         if img_ids is not None: | ||||
|             ids = torch.cat((txt_ids, img_ids), dim=1) | ||||
|             pe = self.pe_embedder(ids) | ||||
|   | ||||
| @@ -433,6 +433,9 @@ class ModelPatcher: | ||||
|     def set_model_double_block_patch(self, patch): | ||||
|         self.set_model_patch(patch, "double_block") | ||||
|  | ||||
|     def set_model_post_input_patch(self, patch): | ||||
|         self.set_model_patch(patch, "post_input") | ||||
|  | ||||
|     def add_object_patch(self, name, obj): | ||||
|         self.object_patches[name] = obj | ||||
|  | ||||
|   | ||||
| @@ -1,4 +1,5 @@ | ||||
| import torch | ||||
| from torch import nn | ||||
| import folder_paths | ||||
| import comfy.utils | ||||
| import comfy.ops | ||||
| @@ -58,6 +59,136 @@ class QwenImageBlockWiseControlNet(torch.nn.Module): | ||||
|         return self.controlnet_blocks[block_id](img, controlnet_conditioning) | ||||
|  | ||||
|  | ||||
| class SigLIPMultiFeatProjModel(torch.nn.Module): | ||||
|     """ | ||||
|     SigLIP Multi-Feature Projection Model for processing style features from different layers | ||||
|     and projecting them into a unified hidden space. | ||||
|  | ||||
|     Args: | ||||
|         siglip_token_nums (int): Number of SigLIP tokens, default 257 | ||||
|         style_token_nums (int): Number of style tokens, default 256 | ||||
|         siglip_token_dims (int): Dimension of SigLIP tokens, default 1536 | ||||
|         hidden_size (int): Hidden layer size, default 3072 | ||||
|         context_layer_norm (bool): Whether to use context layer normalization, default False | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         siglip_token_nums: int = 729, | ||||
|         style_token_nums: int = 64, | ||||
|         siglip_token_dims: int = 1152, | ||||
|         hidden_size: int = 3072, | ||||
|         context_layer_norm: bool = True, | ||||
|         device=None, dtype=None, operations=None | ||||
|     ): | ||||
|         super().__init__() | ||||
|  | ||||
|         # High-level feature processing (layer -2) | ||||
|         self.high_embedding_linear = nn.Sequential( | ||||
|             operations.Linear(siglip_token_nums, style_token_nums), | ||||
|             nn.SiLU() | ||||
|         ) | ||||
|         self.high_layer_norm = ( | ||||
|             operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity() | ||||
|         ) | ||||
|         self.high_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True) | ||||
|  | ||||
|         # Mid-level feature processing (layer -11) | ||||
|         self.mid_embedding_linear = nn.Sequential( | ||||
|             operations.Linear(siglip_token_nums, style_token_nums), | ||||
|             nn.SiLU() | ||||
|         ) | ||||
|         self.mid_layer_norm = ( | ||||
|             operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity() | ||||
|         ) | ||||
|         self.mid_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True) | ||||
|  | ||||
|         # Low-level feature processing (layer -20) | ||||
|         self.low_embedding_linear = nn.Sequential( | ||||
|             operations.Linear(siglip_token_nums, style_token_nums), | ||||
|             nn.SiLU() | ||||
|         ) | ||||
|         self.low_layer_norm = ( | ||||
|             operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity() | ||||
|         ) | ||||
|         self.low_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True) | ||||
|  | ||||
|     def forward(self, siglip_outputs): | ||||
|         """ | ||||
|         Forward pass function | ||||
|  | ||||
|         Args: | ||||
|             siglip_outputs: Output from SigLIP model, containing hidden_states | ||||
|  | ||||
|         Returns: | ||||
|             torch.Tensor: Concatenated multi-layer features with shape [bs, 3*style_token_nums, hidden_size] | ||||
|         """ | ||||
|         dtype = next(self.high_embedding_linear.parameters()).dtype | ||||
|  | ||||
|         # Process high-level features (layer -2) | ||||
|         high_embedding = self._process_layer_features( | ||||
|             siglip_outputs[2], | ||||
|             self.high_embedding_linear, | ||||
|             self.high_layer_norm, | ||||
|             self.high_projection, | ||||
|             dtype | ||||
|         ) | ||||
|  | ||||
|         # Process mid-level features (layer -11) | ||||
|         mid_embedding = self._process_layer_features( | ||||
|             siglip_outputs[1], | ||||
|             self.mid_embedding_linear, | ||||
|             self.mid_layer_norm, | ||||
|             self.mid_projection, | ||||
|             dtype | ||||
|         ) | ||||
|  | ||||
|         # Process low-level features (layer -20) | ||||
|         low_embedding = self._process_layer_features( | ||||
|             siglip_outputs[0], | ||||
|             self.low_embedding_linear, | ||||
|             self.low_layer_norm, | ||||
|             self.low_projection, | ||||
|             dtype | ||||
|         ) | ||||
|  | ||||
|         # Concatenate features from all layersmodel_patch | ||||
|         return torch.cat((high_embedding, mid_embedding, low_embedding), dim=1) | ||||
|  | ||||
|     def _process_layer_features( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         embedding_linear: nn.Module, | ||||
|         layer_norm: nn.Module, | ||||
|         projection: nn.Module, | ||||
|         dtype: torch.dtype | ||||
|     ) -> torch.Tensor: | ||||
|         """ | ||||
|         Helper function to process features from a single layer | ||||
|  | ||||
|         Args: | ||||
|             hidden_states: Input hidden states [bs, seq_len, dim] | ||||
|             embedding_linear: Embedding linear layer | ||||
|             layer_norm: Layer normalization | ||||
|             projection: Projection layer | ||||
|             dtype: Target data type | ||||
|  | ||||
|         Returns: | ||||
|             torch.Tensor: Processed features [bs, style_token_nums, hidden_size] | ||||
|         """ | ||||
|         # Transform dimensions: [bs, seq_len, dim] -> [bs, dim, seq_len] -> [bs, dim, style_token_nums] -> [bs, style_token_nums, dim] | ||||
|         embedding = embedding_linear( | ||||
|             hidden_states.to(dtype).transpose(1, 2) | ||||
|         ).transpose(1, 2) | ||||
|  | ||||
|         # Apply layer normalization | ||||
|         embedding = layer_norm(embedding) | ||||
|  | ||||
|         # Project to target hidden space | ||||
|         embedding = projection(embedding) | ||||
|  | ||||
|         return embedding | ||||
|  | ||||
| class ModelPatchLoader: | ||||
|     @classmethod | ||||
|     def INPUT_TYPES(s): | ||||
| @@ -73,9 +204,14 @@ class ModelPatchLoader: | ||||
|         model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name) | ||||
|         sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True) | ||||
|         dtype = comfy.utils.weight_dtype(sd) | ||||
|         # TODO: this node will work with more types of model patches | ||||
|         additional_in_dim = sd["img_in.weight"].shape[1] - 64 | ||||
|         model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) | ||||
|  | ||||
|         if 'controlnet_blocks.0.y_rms.weight' in sd: | ||||
|             additional_in_dim = sd["img_in.weight"].shape[1] - 64 | ||||
|             model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) | ||||
|         elif 'feature_embedder.mid_layer_norm.bias' in sd: | ||||
|             sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True) | ||||
|             model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) | ||||
|  | ||||
|         model.load_state_dict(sd) | ||||
|         model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) | ||||
|         return (model,) | ||||
| @@ -157,7 +293,51 @@ class QwenImageDiffsynthControlnet: | ||||
|         return (model_patched,) | ||||
|  | ||||
|  | ||||
| class UsoStyleProjectorPatch: | ||||
|     def __init__(self, model_patch, encoded_image): | ||||
|         self.model_patch = model_patch | ||||
|         self.encoded_image = encoded_image | ||||
|  | ||||
|     def __call__(self, kwargs): | ||||
|         txt_ids = kwargs.get("txt_ids") | ||||
|         txt = kwargs.get("txt") | ||||
|         siglip_embedding = self.model_patch.model(self.encoded_image.to(txt.dtype)).to(txt.dtype) | ||||
|         txt = torch.cat([siglip_embedding, txt], dim=1) | ||||
|         kwargs['txt'] = txt | ||||
|         kwargs['txt_ids'] = torch.cat([torch.zeros(siglip_embedding.shape[0], siglip_embedding.shape[1], 3, dtype=txt_ids.dtype, device=txt_ids.device), txt_ids], dim=1) | ||||
|         return kwargs | ||||
|  | ||||
|     def to(self, device_or_dtype): | ||||
|         if isinstance(device_or_dtype, torch.device): | ||||
|             self.encoded_image = self.encoded_image.to(device_or_dtype) | ||||
|         return self | ||||
|  | ||||
|     def models(self): | ||||
|         return [self.model_patch] | ||||
|  | ||||
|  | ||||
| class USOStyleReference: | ||||
|     @classmethod | ||||
|     def INPUT_TYPES(s): | ||||
|         return {"required": {"model": ("MODEL",), | ||||
|                              "model_patch": ("MODEL_PATCH",), | ||||
|                              "clip_vision_output": ("CLIP_VISION_OUTPUT", ), | ||||
|                               }} | ||||
|     RETURN_TYPES = ("MODEL",) | ||||
|     FUNCTION = "apply_patch" | ||||
|     EXPERIMENTAL = True | ||||
|  | ||||
|     CATEGORY = "advanced/model_patches/flux" | ||||
|  | ||||
|     def apply_patch(self, model, model_patch, clip_vision_output): | ||||
|         encoded_image = torch.stack((clip_vision_output.all_hidden_states[:, -20], clip_vision_output.all_hidden_states[:, -11], clip_vision_output.penultimate_hidden_states)) | ||||
|         model_patched = model.clone() | ||||
|         model_patched.set_model_post_input_patch(UsoStyleProjectorPatch(model_patch, encoded_image)) | ||||
|         return (model_patched,) | ||||
|  | ||||
|  | ||||
| NODE_CLASS_MAPPINGS = { | ||||
|     "ModelPatchLoader": ModelPatchLoader, | ||||
|     "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, | ||||
|     "USOStyleReference": USOStyleReference, | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user