From ff57793659702d502506047445f0972b10b6b9fe Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 21 Aug 2025 21:53:11 -0700 Subject: [PATCH] Support InstantX Qwen controlnet. (#9488) --- comfy/controlnet.py | 13 +++++ comfy/ldm/qwen_image/controlnet.py | 77 ++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 comfy/ldm/qwen_image/controlnet.py diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 6cb69dcdf..e3dfedf55 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -36,6 +36,7 @@ import comfy.ldm.cascade.controlnet import comfy.cldm.mmdit import comfy.ldm.hydit.controlnet import comfy.ldm.flux.controlnet +import comfy.ldm.qwen_image.controlnet import comfy.cldm.dit_embedder from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -582,6 +583,15 @@ def load_controlnet_flux_instantx(sd, model_options={}): control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) return control +def load_controlnet_qwen_instantx(sd, model_options={}): + model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options) + control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) + control_model = controlnet_load_state_dict(control_model, sd) + latent_format = comfy.latent_formats.Wan21() + extra_conds = [] + control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) + return control + def convert_mistoline(sd): return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."}) @@ -655,8 +665,11 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}): return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format else: return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet + elif "transformer_blocks.0.img_mlp.net.0.proj.weight" in controlnet_data: + return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options) elif "controlnet_x_embedder.weight" in controlnet_data: return load_controlnet_flux_instantx(controlnet_data, model_options=model_options) + elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options) diff --git a/comfy/ldm/qwen_image/controlnet.py b/comfy/ldm/qwen_image/controlnet.py new file mode 100644 index 000000000..92ac3cf0a --- /dev/null +++ b/comfy/ldm/qwen_image/controlnet.py @@ -0,0 +1,77 @@ +import torch +import math + +from .model import QwenImageTransformer2DModel + + +class QwenImageControlNetModel(QwenImageTransformer2DModel): + def __init__( + self, + extra_condition_channels=0, + dtype=None, + device=None, + operations=None, + **kwargs + ): + super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs) + self.main_model_double = 60 + + # controlnet_blocks + self.controlnet_blocks = torch.nn.ModuleList([]) + for _ in range(len(self.transformer_blocks)): + self.controlnet_blocks.append(operations.Linear(self.inner_dim, self.inner_dim, device=device, dtype=dtype)) + self.controlnet_x_embedder = operations.Linear(self.in_channels + extra_condition_channels, self.inner_dim, device=device, dtype=dtype) + + def forward( + self, + x, + timesteps, + context, + attention_mask=None, + guidance: torch.Tensor = None, + ref_latents=None, + hint=None, + transformer_options={}, + **kwargs + ): + timestep = timesteps + encoder_hidden_states = context + encoder_hidden_states_mask = attention_mask + + hidden_states, img_ids, orig_shape = self.process_img(x) + hint, _, _ = self.process_img(hint) + + txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) + txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + del ids, txt_ids, img_ids + + hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint) + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance * 1000 + + temb = ( + self.time_text_embed(timestep, hidden_states) + if guidance is None + else self.time_text_embed(timestep, guidance, hidden_states) + ) + + repeat = math.ceil(self.main_model_double / len(self.controlnet_blocks)) + + controlnet_block_samples = () + for i, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + controlnet_block_samples = controlnet_block_samples + (self.controlnet_blocks[i](hidden_states),) * repeat + + return {"input": controlnet_block_samples[:self.main_model_double]}