mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-10 11:35:40 +00:00
Support diffsynth inpaint controlnet (model patch). (#9471)
This commit is contained in:
@@ -35,6 +35,7 @@ class QwenImageBlockWiseControlNet(torch.nn.Module):
|
|||||||
device=None, dtype=None, operations=None
|
device=None, dtype=None, operations=None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.additional_in_dim = additional_in_dim
|
||||||
self.img_in = operations.Linear(in_dim + additional_in_dim, dim, device=device, dtype=dtype)
|
self.img_in = operations.Linear(in_dim + additional_in_dim, dim, device=device, dtype=dtype)
|
||||||
self.controlnet_blocks = torch.nn.ModuleList(
|
self.controlnet_blocks = torch.nn.ModuleList(
|
||||||
[
|
[
|
||||||
@@ -44,7 +45,7 @@ class QwenImageBlockWiseControlNet(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def process_input_latent_image(self, latent_image):
|
def process_input_latent_image(self, latent_image):
|
||||||
latent_image = comfy.latent_formats.Wan21().process_in(latent_image)
|
latent_image[:, :16] = comfy.latent_formats.Wan21().process_in(latent_image[:, :16])
|
||||||
patch_size = 2
|
patch_size = 2
|
||||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(latent_image, (1, patch_size, patch_size))
|
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(latent_image, (1, patch_size, patch_size))
|
||||||
orig_shape = hidden_states.shape
|
orig_shape = hidden_states.shape
|
||||||
@@ -73,19 +74,33 @@ class ModelPatchLoader:
|
|||||||
sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True)
|
sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True)
|
||||||
dtype = comfy.utils.weight_dtype(sd)
|
dtype = comfy.utils.weight_dtype(sd)
|
||||||
# TODO: this node will work with more types of model patches
|
# TODO: this node will work with more types of model patches
|
||||||
model = QwenImageBlockWiseControlNet(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
additional_in_dim = sd["img_in.weight"].shape[1] - 64
|
||||||
|
model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||||
model.load_state_dict(sd)
|
model.load_state_dict(sd)
|
||||||
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||||
return (model,)
|
return (model,)
|
||||||
|
|
||||||
|
|
||||||
class DiffSynthCnetPatch:
|
class DiffSynthCnetPatch:
|
||||||
def __init__(self, model_patch, vae, image, strength):
|
def __init__(self, model_patch, vae, image, strength, mask=None):
|
||||||
self.encoded_image = model_patch.model.process_input_latent_image(vae.encode(image))
|
|
||||||
self.model_patch = model_patch
|
self.model_patch = model_patch
|
||||||
self.vae = vae
|
self.vae = vae
|
||||||
self.image = image
|
self.image = image
|
||||||
self.strength = strength
|
self.strength = strength
|
||||||
|
self.mask = mask
|
||||||
|
self.encoded_image = model_patch.model.process_input_latent_image(self.encode_latent_cond(image))
|
||||||
|
|
||||||
|
def encode_latent_cond(self, image):
|
||||||
|
latent_image = self.vae.encode(image)
|
||||||
|
if self.model_patch.model.additional_in_dim > 0:
|
||||||
|
if self.mask is None:
|
||||||
|
mask_ = torch.ones_like(latent_image)[:, :self.model_patch.model.additional_in_dim // 4]
|
||||||
|
else:
|
||||||
|
mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none")
|
||||||
|
|
||||||
|
return torch.cat([latent_image, mask_], dim=1)
|
||||||
|
else:
|
||||||
|
return latent_image
|
||||||
|
|
||||||
def __call__(self, kwargs):
|
def __call__(self, kwargs):
|
||||||
x = kwargs.get("x")
|
x = kwargs.get("x")
|
||||||
@@ -95,7 +110,7 @@ class DiffSynthCnetPatch:
|
|||||||
spacial_compression = self.vae.spacial_compression_encode()
|
spacial_compression = self.vae.spacial_compression_encode()
|
||||||
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
|
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
|
||||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||||
self.encoded_image = self.model_patch.model.process_input_latent_image(self.vae.encode(image_scaled.movedim(1, -1)))
|
self.encoded_image = self.model_patch.model.process_input_latent_image(self.encode_latent_cond(image_scaled.movedim(1, -1)))
|
||||||
comfy.model_management.load_models_gpu(loaded_models)
|
comfy.model_management.load_models_gpu(loaded_models)
|
||||||
|
|
||||||
img = img + (self.model_patch.model.control_block(img, self.encoded_image.to(img.dtype), block_index) * self.strength)
|
img = img + (self.model_patch.model.control_block(img, self.encoded_image.to(img.dtype), block_index) * self.strength)
|
||||||
@@ -118,17 +133,25 @@ class QwenImageDiffsynthControlnet:
|
|||||||
"vae": ("VAE",),
|
"vae": ("VAE",),
|
||||||
"image": ("IMAGE",),
|
"image": ("IMAGE",),
|
||||||
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
||||||
}}
|
},
|
||||||
|
"optional": {"mask": ("MASK",)}}
|
||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "diffsynth_controlnet"
|
FUNCTION = "diffsynth_controlnet"
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
CATEGORY = "advanced/loaders/qwen"
|
CATEGORY = "advanced/loaders/qwen"
|
||||||
|
|
||||||
def diffsynth_controlnet(self, model, model_patch, vae, image, strength):
|
def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None):
|
||||||
model_patched = model.clone()
|
model_patched = model.clone()
|
||||||
image = image[:, :, :, :3]
|
image = image[:, :, :, :3]
|
||||||
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength))
|
if mask is not None:
|
||||||
|
if mask.ndim == 3:
|
||||||
|
mask = mask.unsqueeze(1)
|
||||||
|
if mask.ndim == 4:
|
||||||
|
mask = mask.unsqueeze(2)
|
||||||
|
mask = 1.0 - mask
|
||||||
|
|
||||||
|
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
|
||||||
return (model_patched,)
|
return (model_patched,)
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user