From 0963493a9c3b6565f8537288a0fb90991391ec41 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 20 Aug 2025 19:26:37 -0700 Subject: [PATCH] Support for Qwen Diffsynth Controlnets canny and depth. (#9465) These are not real controlnets but actually a patch on the model so they will be treated as such. Put them in the models/model_patches/ folder. Use the new ModelPatchLoader and QwenImageDiffsynthControlnet nodes. --- comfy/ldm/qwen_image/model.py | 7 + comfy/model_management.py | 8 +- comfy/model_patcher.py | 27 ++++ comfy_api/latest/_io.py | 4 + comfy_extras/nodes_model_patch.py | 138 ++++++++++++++++++++ models/model_patches/put_model_patches_here | 0 nodes.py | 1 + 7 files changed, 184 insertions(+), 1 deletion(-) create mode 100644 comfy_extras/nodes_model_patch.py create mode 100644 models/model_patches/put_model_patches_here diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 49f66b90a..2503583cb 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -416,6 +416,7 @@ class QwenImageTransformer2DModel(nn.Module): ) patches_replace = transformer_options.get("patches_replace", {}) + patches = transformer_options.get("patches", {}) blocks_replace = patches_replace.get("dit", {}) for i, block in enumerate(self.transformer_blocks): @@ -436,6 +437,12 @@ class QwenImageTransformer2DModel(nn.Module): image_rotary_emb=image_rotary_emb, ) + if "double_block" in patches: + for p in patches["double_block"]: + out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i}) + hidden_states = out["img"] + encoder_hidden_states = out["txt"] + hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) diff --git a/comfy/model_management.py b/comfy/model_management.py index 2a9f18068..d08aee1fe 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -593,7 +593,13 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu else: minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory()) - models = set(models) + models_temp = set() + for m in models: + models_temp.add(m) + for mm in m.model_patches_models(): + models_temp.add(mm) + + models = models_temp models_to_load = [] diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 52e76b5f3..a944cb421 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -430,6 +430,9 @@ class ModelPatcher: def set_model_forward_timestep_embed_patch(self, patch): self.set_model_patch(patch, "forward_timestep_embed_patch") + def set_model_double_block_patch(self, patch): + self.set_model_patch(patch, "double_block") + def add_object_patch(self, name, obj): self.object_patches[name] = obj @@ -486,6 +489,30 @@ class ModelPatcher: if hasattr(wrap_func, "to"): self.model_options["model_function_wrapper"] = wrap_func.to(device) + def model_patches_models(self): + to = self.model_options["transformer_options"] + models = [] + if "patches" in to: + patches = to["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if hasattr(patch_list[i], "models"): + models += patch_list[i].models() + if "patches_replace" in to: + patches = to["patches_replace"] + for name in patches: + patch_list = patches[name] + for k in patch_list: + if hasattr(patch_list[k], "models"): + models += patch_list[k].models() + if "model_function_wrapper" in self.model_options: + wrap_func = self.model_options["model_function_wrapper"] + if hasattr(wrap_func, "models"): + models += wrap_func.models() + + return models + def model_dtype(self): if hasattr(self.model, "get_dtype"): return self.model.get_dtype() diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index ec1efb51d..a3a21facc 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -726,6 +726,10 @@ class SEGS(ComfyTypeIO): class AnyType(ComfyTypeIO): Type = Any +@comfytype(io_type="MODEL_PATCH") +class MODEL_PATCH(ComfyTypeIO): + Type = Any + @comfytype(io_type="COMFY_MULTITYPED_V3") class MultiType: Type = Any diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py new file mode 100644 index 000000000..bb239bc45 --- /dev/null +++ b/comfy_extras/nodes_model_patch.py @@ -0,0 +1,138 @@ +import torch +import folder_paths +import comfy.utils +import comfy.ops +import comfy.model_management +import comfy.ldm.common_dit +import comfy.latent_formats + + +class BlockWiseControlBlock(torch.nn.Module): + # [linear, gelu, linear] + def __init__(self, dim: int = 3072, device=None, dtype=None, operations=None): + super().__init__() + self.x_rms = operations.RMSNorm(dim, eps=1e-6) + self.y_rms = operations.RMSNorm(dim, eps=1e-6) + self.input_proj = operations.Linear(dim, dim) + self.act = torch.nn.GELU() + self.output_proj = operations.Linear(dim, dim) + + def forward(self, x, y): + x, y = self.x_rms(x), self.y_rms(y) + x = self.input_proj(x + y) + x = self.act(x) + x = self.output_proj(x) + return x + + +class QwenImageBlockWiseControlNet(torch.nn.Module): + def __init__( + self, + num_layers: int = 60, + in_dim: int = 64, + additional_in_dim: int = 0, + dim: int = 3072, + device=None, dtype=None, operations=None + ): + super().__init__() + self.img_in = operations.Linear(in_dim + additional_in_dim, dim, device=device, dtype=dtype) + self.controlnet_blocks = torch.nn.ModuleList( + [ + BlockWiseControlBlock(dim, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ] + ) + + def process_input_latent_image(self, latent_image): + latent_image = comfy.latent_formats.Wan21().process_in(latent_image) + patch_size = 2 + hidden_states = comfy.ldm.common_dit.pad_to_patch_size(latent_image, (1, patch_size, patch_size)) + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) + hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5) + hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) + return self.img_in(hidden_states) + + def control_block(self, img, controlnet_conditioning, block_id): + return self.controlnet_blocks[block_id](img, controlnet_conditioning) + + +class ModelPatchLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "name": (folder_paths.get_filename_list("model_patches"), ), + }} + RETURN_TYPES = ("MODEL_PATCH",) + FUNCTION = "load_model_patch" + EXPERIMENTAL = True + + CATEGORY = "advanced/loaders" + + def load_model_patch(self, name): + model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name) + sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True) + dtype = comfy.utils.weight_dtype(sd) + # TODO: this node will work with more types of model patches + model = QwenImageBlockWiseControlNet(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + model.load_state_dict(sd) + model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) + return (model,) + + +class DiffSynthCnetPatch: + def __init__(self, model_patch, vae, image, strength): + self.encoded_image = model_patch.model.process_input_latent_image(vae.encode(image)) + self.model_patch = model_patch + self.vae = vae + self.image = image + self.strength = strength + + def __call__(self, kwargs): + x = kwargs.get("x") + img = kwargs.get("img") + block_index = kwargs.get("block_index") + if self.encoded_image is None or self.encoded_image.shape[1:] != img.shape[1:]: + spacial_compression = self.vae.spacial_compression_encode() + image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center") + loaded_models = comfy.model_management.loaded_models(only_currently_used=True) + self.encoded_image = self.model_patch.model.process_input_latent_image(self.vae.encode(image_scaled.movedim(1, -1))) + comfy.model_management.load_models_gpu(loaded_models) + + img = img + (self.model_patch.model.control_block(img, self.encoded_image.to(img.dtype), block_index) * self.strength) + kwargs['img'] = img + return kwargs + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + self.encoded_image = self.encoded_image.to(device_or_dtype) + return self + + def models(self): + return [self.model_patch] + +class QwenImageDiffsynthControlnet: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "model_patch": ("MODEL_PATCH",), + "vae": ("VAE",), + "image": ("IMAGE",), + "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "diffsynth_controlnet" + EXPERIMENTAL = True + + CATEGORY = "advanced/loaders/qwen" + + def diffsynth_controlnet(self, model, model_patch, vae, image, strength): + model_patched = model.clone() + image = image[:, :, :, :3] + model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength)) + return (model_patched,) + + +NODE_CLASS_MAPPINGS = { + "ModelPatchLoader": ModelPatchLoader, + "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, +} diff --git a/models/model_patches/put_model_patches_here b/models/model_patches/put_model_patches_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index 35dda1b19..9681750d3 100644 --- a/nodes.py +++ b/nodes.py @@ -2322,6 +2322,7 @@ async def init_builtin_extra_nodes(): "nodes_tcfg.py", "nodes_context_windows.py", "nodes_qwen.py", + "nodes_model_patch.py" ] import_failed = []