From f7bd5e58dd03e799e02f6851b84b51e14ad0da7b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 21 Aug 2025 20:18:04 -0700 Subject: [PATCH] Make it easier to implement future qwen controlnets. (#9485) --- comfy/controlnet.py | 4 ++-- comfy/ldm/qwen_image/model.py | 16 +++++++++++++--- comfy/model_detection.py | 2 ++ 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 988acdb57..6cb69dcdf 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -236,11 +236,11 @@ class ControlNet(ControlBase): self.cond_hint = None compression_ratio = self.compression_ratio if self.vae is not None: - compression_ratio *= self.vae.downscale_ratio + compression_ratio *= self.vae.spacial_compression_encode() else: if self.latent_format is not None: raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.") - self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center") + self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[-1] * compression_ratio, x_noisy.shape[-2] * compression_ratio, self.upscale_algorithm, "center") self.cond_hint = self.preprocess_image(self.cond_hint) if self.vae is not None: loaded_models = comfy.model_management.loaded_models(only_currently_used=True) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 2503583cb..d0e39833a 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -293,6 +293,7 @@ class QwenImageTransformer2DModel(nn.Module): guidance_embeds: bool = False, axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), image_model=None, + final_layer=True, dtype=None, device=None, operations=None, @@ -300,6 +301,7 @@ class QwenImageTransformer2DModel(nn.Module): super().__init__() self.dtype = dtype self.patch_size = patch_size + self.in_channels = in_channels self.out_channels = out_channels or in_channels self.inner_dim = num_attention_heads * attention_head_dim @@ -329,9 +331,9 @@ class QwenImageTransformer2DModel(nn.Module): for _ in range(num_layers) ]) - self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) - self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) - self.gradient_checkpointing = False + if final_layer: + self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) + self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) def process_img(self, x, index=0, h_offset=0, w_offset=0): bs, c, t, h, w = x.shape @@ -362,6 +364,7 @@ class QwenImageTransformer2DModel(nn.Module): guidance: torch.Tensor = None, ref_latents=None, transformer_options={}, + control=None, **kwargs ): timestep = timesteps @@ -443,6 +446,13 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states = out["img"] encoder_hidden_states = out["txt"] + if control is not None: # Controlnet + control_i = control.get("input") + if i < len(control_i): + add = control_i[i] + if add is not None: + hidden_states += add + hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 2bec0541e..0caff53e0 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -492,6 +492,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image dit_config = {} dit_config["image_model"] = "qwen_image" + dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1] + dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') return dit_config if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: