mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 20:17:30 +00:00
Support InstantX Qwen controlnet. (#9488)
This commit is contained in:
@@ -36,6 +36,7 @@ import comfy.ldm.cascade.controlnet
|
|||||||
import comfy.cldm.mmdit
|
import comfy.cldm.mmdit
|
||||||
import comfy.ldm.hydit.controlnet
|
import comfy.ldm.hydit.controlnet
|
||||||
import comfy.ldm.flux.controlnet
|
import comfy.ldm.flux.controlnet
|
||||||
|
import comfy.ldm.qwen_image.controlnet
|
||||||
import comfy.cldm.dit_embedder
|
import comfy.cldm.dit_embedder
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -582,6 +583,15 @@ def load_controlnet_flux_instantx(sd, model_options={}):
|
|||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
def load_controlnet_qwen_instantx(sd, model_options={}):
|
||||||
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||||
|
control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
|
control_model = controlnet_load_state_dict(control_model, sd)
|
||||||
|
latent_format = comfy.latent_formats.Wan21()
|
||||||
|
extra_conds = []
|
||||||
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
|
return control
|
||||||
|
|
||||||
def convert_mistoline(sd):
|
def convert_mistoline(sd):
|
||||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||||
|
|
||||||
@@ -655,8 +665,11 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
|||||||
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
|
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
|
||||||
else:
|
else:
|
||||||
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
||||||
|
elif "transformer_blocks.0.img_mlp.net.0.proj.weight" in controlnet_data:
|
||||||
|
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
||||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||||
|
|
||||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||||
|
|
||||||
|
77
comfy/ldm/qwen_image/controlnet.py
Normal file
77
comfy/ldm/qwen_image/controlnet.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import torch
|
||||||
|
import math
|
||||||
|
|
||||||
|
from .model import QwenImageTransformer2DModel
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageControlNetModel(QwenImageTransformer2DModel):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
extra_condition_channels=0,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
||||||
|
self.main_model_double = 60
|
||||||
|
|
||||||
|
# controlnet_blocks
|
||||||
|
self.controlnet_blocks = torch.nn.ModuleList([])
|
||||||
|
for _ in range(len(self.transformer_blocks)):
|
||||||
|
self.controlnet_blocks.append(operations.Linear(self.inner_dim, self.inner_dim, device=device, dtype=dtype))
|
||||||
|
self.controlnet_x_embedder = operations.Linear(self.in_channels + extra_condition_channels, self.inner_dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
timesteps,
|
||||||
|
context,
|
||||||
|
attention_mask=None,
|
||||||
|
guidance: torch.Tensor = None,
|
||||||
|
ref_latents=None,
|
||||||
|
hint=None,
|
||||||
|
transformer_options={},
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
timestep = timesteps
|
||||||
|
encoder_hidden_states = context
|
||||||
|
encoder_hidden_states_mask = attention_mask
|
||||||
|
|
||||||
|
hidden_states, img_ids, orig_shape = self.process_img(x)
|
||||||
|
hint, _, _ = self.process_img(hint)
|
||||||
|
|
||||||
|
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||||
|
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||||
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
|
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||||
|
del ids, txt_ids, img_ids
|
||||||
|
|
||||||
|
hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)
|
||||||
|
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||||
|
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||||
|
|
||||||
|
if guidance is not None:
|
||||||
|
guidance = guidance * 1000
|
||||||
|
|
||||||
|
temb = (
|
||||||
|
self.time_text_embed(timestep, hidden_states)
|
||||||
|
if guidance is None
|
||||||
|
else self.time_text_embed(timestep, guidance, hidden_states)
|
||||||
|
)
|
||||||
|
|
||||||
|
repeat = math.ceil(self.main_model_double / len(self.controlnet_blocks))
|
||||||
|
|
||||||
|
controlnet_block_samples = ()
|
||||||
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
|
encoder_hidden_states, hidden_states = block(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||||
|
temb=temb,
|
||||||
|
image_rotary_emb=image_rotary_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
controlnet_block_samples = controlnet_block_samples + (self.controlnet_blocks[i](hidden_states),) * repeat
|
||||||
|
|
||||||
|
return {"input": controlnet_block_samples[:self.main_model_double]}
|
Reference in New Issue
Block a user