From 0f2b8525bcafe213e8421a49564a90f926e81f2e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 16 Aug 2025 14:51:28 -0700 Subject: [PATCH] Qwen image model refactor. (#9375) --- comfy/ldm/qwen_image/model.py | 36 +++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 99843f88d..40d8fd979 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -333,21 +333,25 @@ class QwenImageTransformer2DModel(nn.Module): self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) self.gradient_checkpointing = False - def pos_embeds(self, x, context): + def process_img(self, x, index=0, h_offset=0, w_offset=0): bs, c, t, h, w = x.shape patch_size = self.patch_size + hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size)) + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) + hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5) + hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) h_len = ((h + (patch_size // 2)) // patch_size) w_len = ((w + (patch_size // 2)) // patch_size) - img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + h_offset = ((h_offset + (patch_size // 2)) // patch_size) + w_offset = ((w_offset + (patch_size // 2)) // patch_size) - txt_start = round(max(h_len, w_len)) - txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3) - ids = torch.cat((txt_ids, img_ids), dim=1) - return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids[:, :, 0] = img_ids[:, :, 1] + index + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape def forward( self, @@ -363,13 +367,13 @@ class QwenImageTransformer2DModel(nn.Module): encoder_hidden_states = context encoder_hidden_states_mask = attention_mask - image_rotary_emb = self.pos_embeds(x, context) + hidden_states, img_ids, orig_shape = self.process_img(x) + num_embeds = hidden_states.shape[1] - hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size)) - orig_shape = hidden_states.shape - hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) - hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5) - hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) + txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size))) + txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) hidden_states = self.img_in(hidden_states) encoder_hidden_states = self.txt_norm(encoder_hidden_states) @@ -408,6 +412,6 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) + hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5) return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]