P2 of qwen edit model. (#9412)

* P2 of qwen edit model.

* Typo.

* Fix normal qwen.

* Fix.

* Make the TextEncodeQwenImageEdit also set the ref latent.

If you don't want it to set the ref latent and want to use the
ReferenceLatent node with your custom latent instead just disconnect the
VAE.
This commit is contained in:
comfyanonymous
2025-08-18 19:38:34 -07:00
committed by GitHub
parent bd2ab73976
commit 4977f203fa
10 changed files with 565 additions and 15 deletions

View File

@@ -204,17 +204,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
index = 0
pad_extra = 0
embeds_info = []
for o in other_embeds:
emb = o[1]
if torch.is_tensor(emb):
emb = {"type": "embedding", "data": emb}
extra = None
emb_type = emb.get("type", None)
if emb_type == "embedding":
emb = emb.get("data", None)
else:
if hasattr(self.transformer, "preprocess_embed"):
emb = self.transformer.preprocess_embed(emb, device=device)
emb, extra = self.transformer.preprocess_embed(emb, device=device)
else:
emb = None
@@ -229,6 +231,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
index += emb_shape - 1
embeds_info.append({"type": emb_type, "index": ind, "size": emb_shape, "extra": extra})
else:
index += -1
pad_extra += emb_shape
@@ -243,11 +246,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
attention_masks.append(attention_mask)
num_tokens.append(sum(attention_mask))
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info
def forward(self, tokens):
device = self.transformer.get_input_embeddings().weight.device
embeds, attention_mask, num_tokens = self.process_tokens(tokens, device)
embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)
attention_mask_model = None
if self.enable_attention_masks:
@@ -258,7 +261,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
else:
intermediate_output = self.layer_idx
outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32, embeds_info=embeds_info)
if self.layer == "last":
z = outputs[0].float()