From 1b2de2642d38099acdde7c460d133d93e91074f0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 20 Aug 2025 21:33:49 -0700 Subject: [PATCH] Support diffsynth inpaint controlnet (model patch). (#9471) --- comfy_extras/nodes_model_patch.py | 39 ++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index bb239bc45..3eaada9bc 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -35,6 +35,7 @@ class QwenImageBlockWiseControlNet(torch.nn.Module): device=None, dtype=None, operations=None ): super().__init__() + self.additional_in_dim = additional_in_dim self.img_in = operations.Linear(in_dim + additional_in_dim, dim, device=device, dtype=dtype) self.controlnet_blocks = torch.nn.ModuleList( [ @@ -44,7 +45,7 @@ class QwenImageBlockWiseControlNet(torch.nn.Module): ) def process_input_latent_image(self, latent_image): - latent_image = comfy.latent_formats.Wan21().process_in(latent_image) + latent_image[:, :16] = comfy.latent_formats.Wan21().process_in(latent_image[:, :16]) patch_size = 2 hidden_states = comfy.ldm.common_dit.pad_to_patch_size(latent_image, (1, patch_size, patch_size)) orig_shape = hidden_states.shape @@ -73,19 +74,33 @@ class ModelPatchLoader: sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True) dtype = comfy.utils.weight_dtype(sd) # TODO: this node will work with more types of model patches - model = QwenImageBlockWiseControlNet(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + additional_in_dim = sd["img_in.weight"].shape[1] - 64 + model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) model.load_state_dict(sd) model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) return (model,) class DiffSynthCnetPatch: - def __init__(self, model_patch, vae, image, strength): - self.encoded_image = model_patch.model.process_input_latent_image(vae.encode(image)) + def __init__(self, model_patch, vae, image, strength, mask=None): self.model_patch = model_patch self.vae = vae self.image = image self.strength = strength + self.mask = mask + self.encoded_image = model_patch.model.process_input_latent_image(self.encode_latent_cond(image)) + + def encode_latent_cond(self, image): + latent_image = self.vae.encode(image) + if self.model_patch.model.additional_in_dim > 0: + if self.mask is None: + mask_ = torch.ones_like(latent_image)[:, :self.model_patch.model.additional_in_dim // 4] + else: + mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none") + + return torch.cat([latent_image, mask_], dim=1) + else: + return latent_image def __call__(self, kwargs): x = kwargs.get("x") @@ -95,7 +110,7 @@ class DiffSynthCnetPatch: spacial_compression = self.vae.spacial_compression_encode() image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center") loaded_models = comfy.model_management.loaded_models(only_currently_used=True) - self.encoded_image = self.model_patch.model.process_input_latent_image(self.vae.encode(image_scaled.movedim(1, -1))) + self.encoded_image = self.model_patch.model.process_input_latent_image(self.encode_latent_cond(image_scaled.movedim(1, -1))) comfy.model_management.load_models_gpu(loaded_models) img = img + (self.model_patch.model.control_block(img, self.encoded_image.to(img.dtype), block_index) * self.strength) @@ -118,17 +133,25 @@ class QwenImageDiffsynthControlnet: "vae": ("VAE",), "image": ("IMAGE",), "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), - }} + }, + "optional": {"mask": ("MASK",)}} RETURN_TYPES = ("MODEL",) FUNCTION = "diffsynth_controlnet" EXPERIMENTAL = True CATEGORY = "advanced/loaders/qwen" - def diffsynth_controlnet(self, model, model_patch, vae, image, strength): + def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None): model_patched = model.clone() image = image[:, :, :, :3] - model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength)) + if mask is not None: + if mask.ndim == 3: + mask = mask.unsqueeze(1) + if mask.ndim == 4: + mask = mask.unsqueeze(2) + mask = 1.0 - mask + + model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask)) return (model_patched,)