Mirror of https://github.com/comfyanonymous/ComfyUI.git
Make it easier to implement future qwen controlnets. (#9485)
@@ -236,11 +236,11 @@ class ControlNet(ControlBase):
             self.cond_hint = None
             compression_ratio = self.compression_ratio
             if self.vae is not None:
-                compression_ratio *= self.vae.downscale_ratio
+                compression_ratio *= self.vae.spacial_compression_encode()
             else:
                 if self.latent_format is not None:
                     raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
-            self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
+            self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[-1] * compression_ratio, x_noisy.shape[-2] * compression_ratio, self.upscale_algorithm, "center")
             self.cond_hint = self.preprocess_image(self.cond_hint)
             if self.vae is not None:
                 loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
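The two changed lines above make the hint scaling VAE-agnostic: the ratio now comes from the VAE's own spacial_compression_encode() instead of a fixed downscale_ratio attribute, and negative indexing into x_noisy.shape keeps the same code working whether the latent is [B, C, H, W] or carries an extra time dimension. A minimal standalone sketch of the scaling math (not the ComfyUI code path; the factor of 8 and the use of torch.nn.functional.interpolate in place of comfy.utils.common_upscale are assumptions for illustration):

    # Sketch only: scale a pixel-space control hint to match the noisy latent.
    import torch
    import torch.nn.functional as F

    def scale_hint_to_latent(cond_hint, x_noisy, compression_ratio=8):
        # Negative indexing works for both [B, C, H, W] and [B, C, T, H, W]
        # latents, which is why the diff switches from shape[3]/shape[2]
        # to shape[-1]/shape[-2].
        target_w = x_noisy.shape[-1] * compression_ratio
        target_h = x_noisy.shape[-2] * compression_ratio
        return F.interpolate(cond_hint, size=(target_h, target_w), mode="bilinear", align_corners=False)

    hint = torch.rand(1, 3, 300, 400)   # pixel-space control image
    latent = torch.rand(1, 16, 64, 64)  # latent being denoised
    print(scale_hint_to_latent(hint, latent).shape)  # torch.Size([1, 3, 512, 512])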
@@ -293,6 +293,7 @@ class QwenImageTransformer2DModel(nn.Module):
         guidance_embeds: bool = False,
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
         image_model=None,
+        final_layer=True,
         dtype=None,
         device=None,
         operations=None,
@@ -300,6 +301,7 @@ class QwenImageTransformer2DModel(nn.Module):
         super().__init__()
         self.dtype = dtype
         self.patch_size = patch_size
+        self.in_channels = in_channels
         self.out_channels = out_channels or in_channels
         self.inner_dim = num_attention_heads * attention_head_dim
 
@@ -329,9 +331,9 @@ class QwenImageTransformer2DModel(nn.Module):
             for _ in range(num_layers)
         ])
 
-        self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
-        self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
-        self.gradient_checkpointing = False
+        if final_layer:
+            self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
+            self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
 
     def process_img(self, x, index=0, h_offset=0, w_offset=0):
         bs, c, t, h, w = x.shape
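The new final_layer flag lets a model reuse the transformer trunk without allocating the output head. A sketch of how a future controlnet variant might use it (QwenImageControlNetModel and the import path are assumptions for illustration, not ComfyUI API):

    # Hypothetical subclass, illustration only; the import path is assumed.
    from comfy.ldm.qwen_image.model import QwenImageTransformer2DModel

    class QwenImageControlNetModel(QwenImageTransformer2DModel):
        def __init__(self, *args, **kwargs):
            # A controlnet emits per-block residuals rather than a final
            # patchified image prediction, so it can skip norm_out/proj_out.
            super().__init__(*args, final_layer=False, **kwargs)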
@@ -362,6 +364,7 @@ class QwenImageTransformer2DModel(nn.Module):
         guidance: torch.Tensor = None,
         ref_latents=None,
         transformer_options={},
+        control=None,
         **kwargs
     ):
         timestep = timesteps
@@ -443,6 +446,13 @@ class QwenImageTransformer2DModel(nn.Module):
                 hidden_states = out["img"]
                 encoder_hidden_states = out["txt"]
 
+            if control is not None: # Controlnet
+                control_i = control.get("input")
+                if i < len(control_i):
+                    add = control_i[i]
+                    if add is not None:
+                        hidden_states += add
+
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
 
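The added block defines the contract between this model and any future Qwen controlnet: control["input"] is a list of residuals indexed by transformer block, entries may be None, and each present entry is added to hidden_states after block i. A toy, self-contained sketch of that convention (tensor shapes and values are made up for illustration):

    # Standalone sketch of the residual-injection convention, not ComfyUI code.
    import torch

    num_blocks = 4
    hidden_states = torch.zeros(1, 8, 16)  # toy [batch, tokens, dim]
    control = {"input": [torch.ones(1, 8, 16), None, torch.ones(1, 8, 16) * 2]}

    for i in range(num_blocks):
        # ... a transformer block would update hidden_states here ...
        if control is not None:
            control_i = control.get("input")
            if i < len(control_i):       # the list may be shorter than num_blocks
                add = control_i[i]
                if add is not None:      # None means "no residual for this block"
                    hidden_states += add

    print(hidden_states[0, 0, 0].item())  # 3.0 = 1.0 (block 0) + 2.0 (block 2)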
@@ -492,6 +492,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
     if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
         dit_config = {}
         dit_config["image_model"] = "qwen_image"
+        dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
+        dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
         return dit_config
 
     if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
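The two new config entries are read straight off the checkpoint: an nn.Linear weight has shape [out_features, in_features], so img_in.weight.shape[1] is the model's in_channels, and counting the numbered transformer_blocks.N. prefixes gives num_layers, which is effectively what count_blocks does. A small sketch with made-up tensor shapes; the inline regex stands in for count_blocks:

    # Sketch only: the state dict below is fake and its shapes are assumptions.
    import re
    import torch

    state_dict = {
        "img_in.weight": torch.zeros(3072, 64),  # Linear weight: [out_features, in_features]
        "transformer_blocks.0.attn.to_q.weight": torch.zeros(1, 1),
        "transformer_blocks.1.attn.to_q.weight": torch.zeros(1, 1),
    }

    in_channels = state_dict["img_in.weight"].shape[1]  # 64
    block_ids = {int(m.group(1)) for k in state_dict
                 if (m := re.match(r"transformer_blocks\.(\d+)\.", k))}
    num_layers = max(block_ids) + 1 if block_ids else 0  # 2
    print(in_channels, num_layers)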