Make it easier to implement future qwen controlnets. (#9485)

comfyanonymous
2025-08-21 20:18:04 -07:00
committed by GitHub
parent 7ed73d12d1
commit f7bd5e58dd
3 changed files with 17 additions and 5 deletions


@@ -293,6 +293,7 @@ class QwenImageTransformer2DModel(nn.Module):
         guidance_embeds: bool = False,
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
         image_model=None,
+        final_layer=True,
         dtype=None,
         device=None,
         operations=None,
@@ -300,6 +301,7 @@ class QwenImageTransformer2DModel(nn.Module):
         super().__init__()
         self.dtype = dtype
         self.patch_size = patch_size
+        self.in_channels = in_channels
         self.out_channels = out_channels or in_channels
         self.inner_dim = num_attention_heads * attention_head_dim
@@ -329,9 +331,9 @@ class QwenImageTransformer2DModel(nn.Module):
             for _ in range(num_layers)
         ])
 
-        self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
-        self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
         self.gradient_checkpointing = False
 
+        if final_layer:
+            self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
+            self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
 
     def process_img(self, x, index=0, h_offset=0, w_offset=0):
         bs, c, t, h, w = x.shape
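
The new final_layer flag is what makes future controlnets cheaper to build: a variant that only needs the inner transformer blocks can construct the same backbone without allocating norm_out / proj_out at all. A minimal sketch of how that might look (the subclass name below is hypothetical and not part of this commit):

# Hypothetical sketch, not part of this commit: a controlnet-style variant can
# now reuse the backbone without the unused output head.
class QwenImageControlNetSketch(QwenImageTransformer2DModel):
    def __init__(self, *args, **kwargs):
        # final_layer=False skips norm_out / proj_out, so a model that only
        # taps the inner transformer blocks carries no dead parameters.
        super().__init__(*args, final_layer=False, **kwargs)
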
@@ -362,6 +364,7 @@ class QwenImageTransformer2DModel(nn.Module):
         guidance: torch.Tensor = None,
         ref_latents=None,
         transformer_options={},
+        control=None,
         **kwargs
     ):
         timestep = timesteps
@@ -443,6 +446,13 @@ class QwenImageTransformer2DModel(nn.Module):
 
             hidden_states = out["img"]
             encoder_hidden_states = out["txt"]
 
+            if control is not None: # Controlnet
+                control_i = control.get("input")
+                if i < len(control_i):
+                    add = control_i[i]
+                    if add is not None:
+                        hidden_states += add
+
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
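
For reference, the control dict consumed above is indexed per transformer block: entry i of control["input"], when present and not None, is added to hidden_states right after block i, and a list shorter than the number of blocks simply stops injecting early. A rough illustration of the expected layout (the sizes and names below are placeholders, not values from this commit):

import torch

# Placeholder sizes for illustration only; real values depend on the model
# config and the latent resolution.
batch, image_tokens, inner_dim, num_blocks = 1, 4096, 3072, 60

# One optional residual per transformer block; None entries are skipped by the
# forward pass above, as is any block index beyond the end of the list.
control = {"input": [torch.zeros(batch, image_tokens, inner_dim) if i % 2 == 0 else None
                     for i in range(num_blocks)]}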