From ed43784b0d04e5b8e8ff2c057fa84b9c5132aaf2 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 17 Aug 2025 13:45:39 -0700 Subject: [PATCH] WIP Qwen edit model: The diffusion model part. (#9383) --- comfy/ldm/qwen_image/model.py | 26 ++++++++++++++++++++++++++ comfy/model_base.py | 10 ++++++++++ 2 files changed, 36 insertions(+) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 40d8fd979..a3c726299 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -360,6 +360,7 @@ class QwenImageTransformer2DModel(nn.Module): context, attention_mask=None, guidance: torch.Tensor = None, + ref_latents=None, transformer_options={}, **kwargs ): @@ -370,6 +371,31 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states, img_ids, orig_shape = self.process_img(x) num_embeds = hidden_states.shape[1] + if ref_latents is not None: + h = 0 + w = 0 + index = 0 + index_ref_method = kwargs.get("ref_latents_method", "index") == "index" + for ref in ref_latents: + if index_ref_method: + index += 1 + h_offset = 0 + w_offset = 0 + else: + index = 1 + h_offset = 0 + w_offset = 0 + if ref.shape[-2] + h > ref.shape[-1] + w: + w_offset = w + else: + h_offset = h + h = max(h, ref.shape[-2] + h_offset) + w = max(w, ref.shape[-1] + w_offset) + + kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) + hidden_states = torch.cat([hidden_states, kontext], dim=1) + img_ids = torch.cat([img_ids, kontext_ids], dim=1) + txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size))) txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) ids = torch.cat((txt_ids, img_ids), dim=1) diff --git a/comfy/model_base.py b/comfy/model_base.py index bf874b875..15bd7abef 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1331,4 +1331,14 @@ class QwenImage(BaseModel): cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + latents = [] + for lat in ref_latents: + latents.append(self.process_latent_in(lat)) + out['ref_latents'] = comfy.conds.CONDList(latents) + + ref_latents_method = kwargs.get("reference_latents_method", None) + if ref_latents_method is not None: + out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method) return out