diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index c15ab8e40..99843f88d 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -356,6 +356,7 @@ class QwenImageTransformer2DModel(nn.Module):
         context,
         attention_mask=None,
         guidance: torch.Tensor = None,
+        transformer_options={},
         **kwargs
     ):
         timestep = timesteps
@@ -383,14 +384,26 @@ class QwenImageTransformer2DModel(nn.Module):
             else self.time_text_embed(timestep, guidance, hidden_states)
         )
 
-        for block in self.transformer_blocks:
-            encoder_hidden_states, hidden_states = block(
-                hidden_states=hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_hidden_states_mask=encoder_hidden_states_mask,
-                temb=temb,
-                image_rotary_emb=image_rotary_emb,
-            )
+        patches_replace = transformer_options.get("patches_replace", {})
+        blocks_replace = patches_replace.get("dit", {})
+
+        for i, block in enumerate(self.transformer_blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"])
+                    return out
+                out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap})
+                hidden_states = out["img"]
+                encoder_hidden_states = out["txt"]
+            else:
+                encoder_hidden_states, hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_hidden_states_mask=encoder_hidden_states_mask,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                )
 
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
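
For reference, a minimal sketch of how a caller could hook one of these blocks through the new transformer_options path. The key layout ("patches_replace" -> "dit" -> ("double_block", i)) and the callable signature (args dict with "img"/"txt"/"vec"/"pe" plus an "original_block" callback) mirror what the diff dispatches to; the patch function itself and the scaling it applies are illustrative assumptions, not part of this change.

    # Illustrative sketch (not part of this diff): replace-patch for double block 0
    # of the Qwen image DiT, wired in through transformer_options.
    def my_double_block_patch(args, extra):
        # args carries the tensors the model hands over: "img", "txt", "vec", "pe".
        out = extra["original_block"](args)   # run the unmodified block
        out["img"] = out["img"] * 1.0         # hypothetical post-processing hook
        return out

    transformer_options = {
        "patches_replace": {
            "dit": {
                ("double_block", 0): my_double_block_patch,
            }
        }
    }
    # This dict would be passed as the transformer_options keyword that the diff
    # threads into QwenImageTransformer2DModel.forward().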