From 4977f203fa8e9e3ab22884c8ace8f9b540d48952 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 18 Aug 2025 19:38:34 -0700 Subject: [PATCH] P2 of qwen edit model. (#9412) * P2 of qwen edit model. * Typo. * Fix normal qwen. * Fix. * Make the TextEncodeQwenImageEdit also set the ref latent. If you don't want it to set the ref latent and want to use the ReferenceLatent node with your custom latent instead just disconnect the VAE. --- comfy/clip_model.py | 2 +- comfy/model_base.py | 8 + comfy/sd1_clip.py | 11 +- comfy/text_encoders/bert.py | 2 +- comfy/text_encoders/llama.py | 43 ++- comfy/text_encoders/qwen_image.py | 20 +- comfy/text_encoders/qwen_vl.py | 428 ++++++++++++++++++++++++++++++ comfy/text_encoders/t5.py | 2 +- comfy_extras/nodes_qwen.py | 63 +++++ nodes.py | 1 + 10 files changed, 565 insertions(+), 15 deletions(-) create mode 100644 comfy/text_encoders/qwen_vl.py create mode 100644 comfy_extras/nodes_qwen.py diff --git a/comfy/clip_model.py b/comfy/clip_model.py index c8294d483..7e47d8a55 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -97,7 +97,7 @@ class CLIPTextModel_(torch.nn.Module): self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device) - def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32): + def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32, embeds_info=[]): if embeds is not None: x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device) else: diff --git a/comfy/model_base.py b/comfy/model_base.py index 15bd7abef..6c861b15e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1325,6 +1325,7 @@ class Omnigen2(BaseModel): class QwenImage(BaseModel): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel) + self.memory_usage_factor_conds = ("ref_latents",) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1342,3 +1343,10 @@ class QwenImage(BaseModel): if ref_latents_method is not None: out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method) return out + + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) + return out diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index ade340fd1..1e8adbe69 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -204,17 +204,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32) index = 0 pad_extra = 0 + embeds_info = [] for o in other_embeds: emb = o[1] if torch.is_tensor(emb): emb = {"type": "embedding", "data": emb} + extra = None emb_type = emb.get("type", None) if emb_type == "embedding": emb = emb.get("data", None) else: if hasattr(self.transformer, "preprocess_embed"): - emb = self.transformer.preprocess_embed(emb, device=device) + emb, extra = self.transformer.preprocess_embed(emb, device=device) else: emb = None @@ -229,6 +231,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1) attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:] index += emb_shape - 1 + embeds_info.append({"type": emb_type, "index": ind, "size": emb_shape, "extra": extra}) else: index += -1 pad_extra += emb_shape @@ -243,11 +246,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): attention_masks.append(attention_mask) num_tokens.append(sum(attention_mask)) - return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens + return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info def forward(self, tokens): device = self.transformer.get_input_embeddings().weight.device - embeds, attention_mask, num_tokens = self.process_tokens(tokens, device) + embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device) attention_mask_model = None if self.enable_attention_masks: @@ -258,7 +261,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): else: intermediate_output = self.layer_idx - outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) + outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32, embeds_info=embeds_info) if self.layer == "last": z = outputs[0].float() diff --git a/comfy/text_encoders/bert.py b/comfy/text_encoders/bert.py index 551b03162..ed4638a9a 100644 --- a/comfy/text_encoders/bert.py +++ b/comfy/text_encoders/bert.py @@ -116,7 +116,7 @@ class BertModel_(torch.nn.Module): self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations) self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations) - def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): + def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype) mask = None if attention_mask is not None: diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 1da6a0c94..9d90d5a61 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -2,12 +2,14 @@ import torch import torch.nn as nn from dataclasses import dataclass from typing import Optional, Any +import math from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management import comfy.ldm.common_dit import comfy.model_management +from . import qwen_vl @dataclass class Llama2Config: @@ -100,12 +102,10 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def precompute_freqs_cis(head_dim, seq_len, theta, device=None): +def precompute_freqs_cis(head_dim, position_ids, theta, device=None): theta_numerator = torch.arange(0, head_dim, 2, device=device).float() inv_freq = 1.0 / (theta ** (theta_numerator / head_dim)) - position_ids = torch.arange(0, seq_len, device=device).unsqueeze(0) - inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) @@ -277,7 +277,7 @@ class Llama2_(nn.Module): self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) - def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]): if embeds is not None: x = embeds else: @@ -286,8 +286,11 @@ class Llama2_(nn.Module): if self.normalize_in: x *= self.config.hidden_size ** 0.5 + if position_ids is None: + position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0) + freqs_cis = precompute_freqs_cis(self.config.head_dim, - x.shape[1], + position_ids, self.config.rope_theta, device=x.device) @@ -372,8 +375,38 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module): self.num_layers = config.num_hidden_layers self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations) self.dtype = dtype + def preprocess_embed(self, embed, device): + if embed["type"] == "image": + image, grid = qwen_vl.process_qwen2vl_images(embed["data"]) + return self.visual(image.to(device, dtype=torch.float32), grid), grid + return None, None + + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): + grid = None + for e in embeds_info: + if e.get("type") == "image": + grid = e.get("extra", None) + position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device) + start = e.get("index") + position_ids[:, :start] = torch.arange(0, start, device=embeds.device) + end = e.get("size") + start + len_max = int(grid.max()) // 2 + start_next = len_max + start + position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device) + position_ids[0, start:end] = start + max_d = int(grid[0][1]) // 2 + position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start] + max_d = int(grid[0][2]) // 2 + position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start] + + if grid is None: + position_ids = None + + return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids) + class Gemma2_2B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() diff --git a/comfy/text_encoders/qwen_image.py b/comfy/text_encoders/qwen_image.py index ce5c98097..f07318d6c 100644 --- a/comfy/text_encoders/qwen_image.py +++ b/comfy/text_encoders/qwen_image.py @@ -15,13 +15,27 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer) self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image \\(color, shape, size, texture, objects, background\\), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" - def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs): + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs): if llama_template is None: - llama_text = self.llama_template.format(text) + if len(images) > 0: + llama_text = self.llama_template_images.format(text) + else: + llama_text = self.llama_template.format(text) else: llama_text = llama_template.format(text) - return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs) + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs) + key_name = next(iter(tokens)) + embed_count = 0 + qwen_tokens = tokens[key_name] + for r in qwen_tokens: + for i in range(len(r)): + if r[i][0] == 151655: + if len(images) > embed_count: + r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:] + embed_count += 1 + return tokens class Qwen25_7BVLIModel(sd1_clip.SDClipModel): diff --git a/comfy/text_encoders/qwen_vl.py b/comfy/text_encoders/qwen_vl.py new file mode 100644 index 000000000..3b18ce730 --- /dev/null +++ b/comfy/text_encoders/qwen_vl.py @@ -0,0 +1,428 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Tuple +import math +from comfy.ldm.modules.attention import optimized_attention_for_device + + +def process_qwen2vl_images( + images: torch.Tensor, + min_pixels: int = 3136, + max_pixels: int = 12845056, + patch_size: int = 14, + temporal_patch_size: int = 2, + merge_size: int = 2, + image_mean: list = None, + image_std: list = None, +): + if image_mean is None: + image_mean = [0.48145466, 0.4578275, 0.40821073] + if image_std is None: + image_std = [0.26862954, 0.26130258, 0.27577711] + + batch_size, height, width, channels = images.shape + device = images.device + # dtype = images.dtype + + images = images.permute(0, 3, 1, 2) + + grid_thw_list = [] + img = images[0] + + factor = patch_size * merge_size + + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, math.floor(height / beta / factor) * factor) + w_bar = max(factor, math.floor(width / beta / factor) * factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + + img_resized = F.interpolate( + img.unsqueeze(0), + size=(h_bar, w_bar), + mode='bilinear', + align_corners=False + ).squeeze(0) + + normalized = img_resized.clone() + for c in range(3): + normalized[c] = (img_resized[c] - image_mean[c]) / image_std[c] + + grid_h = h_bar // patch_size + grid_w = w_bar // patch_size + grid_thw = torch.tensor([1, grid_h, grid_w], device=device, dtype=torch.long) + + pixel_values = normalized + grid_thw_list.append(grid_thw) + image_grid_thw = torch.stack(grid_thw_list) + + grid_t = 1 + channel = pixel_values.shape[0] + pixel_values = pixel_values.unsqueeze(0).repeat(2, 1, 1, 1) + + patches = pixel_values.reshape( + grid_t, + temporal_patch_size, + channel, + grid_h // merge_size, + merge_size, + patch_size, + grid_w // merge_size, + merge_size, + patch_size, + ) + + patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8) + flatten_patches = patches.reshape( + grid_t * grid_h * grid_w, + channel * temporal_patch_size * patch_size * patch_size + ) + + return flatten_patches, image_grid_thw + + +class VisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 3584, + device=None, + dtype=None, + ops=None, + ): + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = ops.Conv3d( + in_channels, + embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + device=device, + dtype=dtype + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states) + return hidden_states.view(-1, self.embed_dim) + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision(q, k, cos, sin): + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0): + super().__init__() + self.dim = dim + self.theta = theta + + def forward(self, seqlen: int, device) -> torch.Tensor: + inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=device) / self.dim)) + seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype) + freqs = torch.outer(seq, inv_freq) + return freqs + + +class PatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2, device=None, dtype=None, ops=None): + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size ** 2) + self.ln_q = ops.RMSNorm(context_dim, eps=1e-6, device=device, dtype=dtype) + self.mlp = nn.Sequential( + ops.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype), + nn.GELU(), + ops.Linear(self.hidden_size, dim, device=device, dtype=dtype), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.ln_q(x).reshape(-1, self.hidden_size) + x = self.mlp(x) + return x + + +class VisionAttention(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, device=None, dtype=None, ops=None): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scaling = self.head_dim ** -0.5 + + self.qkv = ops.Linear(hidden_size, hidden_size * 3, bias=True, device=device, dtype=dtype) + self.proj = ops.Linear(hidden_size, hidden_size, bias=True, device=device, dtype=dtype) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cu_seqlens=None, + optimized_attention=None, + ) -> torch.Tensor: + if hidden_states.dim() == 2: + seq_length, _ = hidden_states.shape + batch_size = 1 + hidden_states = hidden_states.unsqueeze(0) + else: + batch_size, seq_length, _ = hidden_states.shape + + qkv = self.qkv(hidden_states) + qkv = qkv.reshape(batch_size, seq_length, 3, self.num_heads, self.head_dim) + query_states, key_states, value_states = qkv.reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + optimized_attention(q, k, v, self.num_heads, skip_reshape=True) + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + + return attn_output + + +class VisionMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, device=None, dtype=None, ops=None): + super().__init__() + self.gate_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype) + self.up_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype) + self.down_proj = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype) + self.act_fn = nn.SiLU() + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class VisionBlock(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, num_heads: int, device=None, dtype=None, ops=None): + super().__init__() + self.norm1 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype) + self.norm2 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype) + self.attn = VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops) + self.mlp = VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cu_seqlens=None, + optimized_attention=None, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = self.attn(hidden_states, position_embeddings, cu_seqlens, optimized_attention) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Qwen2VLVisionTransformer(nn.Module): + def __init__( + self, + hidden_size: int = 3584, + output_hidden_size: int = 3584, + intermediate_size: int = 3420, + num_heads: int = 16, + num_layers: int = 32, + patch_size: int = 14, + temporal_patch_size: int = 2, + spatial_merge_size: int = 2, + window_size: int = 112, + device=None, + dtype=None, + ops=None + ): + super().__init__() + self.hidden_size = hidden_size + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.window_size = window_size + self.fullatt_block_indexes = [7, 15, 23, 31] + + self.patch_embed = VisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=3, + embed_dim=hidden_size, + device=device, + dtype=dtype, + ops=ops, + ) + + head_dim = hidden_size // num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([ + VisionBlock(hidden_size, intermediate_size, num_heads, device, dtype, ops) + for _ in range(num_layers) + ]) + + self.merger = PatchMerger( + dim=output_hidden_size, + context_dim=hidden_size, + spatial_merge_size=spatial_merge_size, + device=device, + dtype=dtype, + ops=ops, + ) + + def get_window_index(self, grid_thw): + window_index = [] + cu_window_seqlens = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h = grid_h // self.spatial_merge_size + llm_grid_w = grid_w // self.spatial_merge_size + + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_size * self.spatial_merge_size + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + + window_index = torch.cat(window_index, dim=0) + return window_index, cu_window_seqlens + + def get_position_embeddings(self, grid_thw, device): + pos_ids = [] + + for t, h, w in grid_thw: + hpos_ids = torch.arange(h, device=device).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten() + + wpos_ids = torch.arange(w, device=device).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3).flatten() + + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device) + return rotary_pos_emb_full[pos_ids].flatten(1) + + def forward( + self, + pixel_values: torch.Tensor, + image_grid_thw: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + optimized_attention = optimized_attention_for_device(pixel_values.device, mask=False, small_input=True) + + hidden_states = self.patch_embed(pixel_values) + + window_index, cu_window_seqlens = self.get_window_index(image_grid_thw) + cu_window_seqlens = torch.tensor(cu_window_seqlens, device=hidden_states.device) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + position_embeddings = self.get_position_embeddings(image_grid_thw, hidden_states.device) + + seq_len, _ = hidden_states.size() + spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + hidden_states = hidden_states.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + + position_embeddings = position_embeddings.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1) + position_embeddings = position_embeddings[window_index, :, :] + position_embeddings = position_embeddings.reshape(seq_len, -1) + position_embeddings = torch.cat((position_embeddings, position_embeddings), dim=-1) + position_embeddings = (position_embeddings.cos(), position_embeddings.sin()) + + cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum( + dim=0, + dtype=torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for i, block in enumerate(self.blocks): + if i in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention) + + hidden_states = self.merger(hidden_states) + return hidden_states diff --git a/comfy/text_encoders/t5.py b/comfy/text_encoders/t5.py index 36bf35309..e8588992a 100644 --- a/comfy/text_encoders/t5.py +++ b/comfy/text_encoders/t5.py @@ -199,7 +199,7 @@ class T5Stack(torch.nn.Module): self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations) # self.dropout = nn.Dropout(config.dropout_rate) - def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): + def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): mask = None if attention_mask is not None: mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py new file mode 100644 index 000000000..b5088fae2 --- /dev/null +++ b/comfy_extras/nodes_qwen.py @@ -0,0 +1,63 @@ +import node_helpers +import comfy.utils + +PREFERRED_QWENIMAGE_RESOLUTIONS = [ + (672, 1568), + (688, 1504), + (720, 1456), + (752, 1392), + (800, 1328), + (832, 1248), + (880, 1184), + (944, 1104), + (1024, 1024), + (1104, 944), + (1184, 880), + (1248, 832), + (1328, 800), + (1392, 752), + (1456, 720), + (1504, 688), + (1568, 672), +] + + +class TextEncodeQwenImageEdit: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "clip": ("CLIP", ), + "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}), + }, + "optional": {"vae": ("VAE", ), + "image": ("IMAGE", ),}} + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "advanced/conditioning" + + def encode(self, clip, prompt, vae=None, image=None): + ref_latent = None + if image is None: + images = [] + else: + images = [image] + if vae is not None: + width = image.shape[2] + height = image.shape[1] + aspect_ratio = width / height + _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS) + image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1) + ref_latent = vae.encode(image[:, :, :, :3]) + + tokens = clip.tokenize(prompt, images=images) + conditioning = clip.encode_from_tokens_scheduled(tokens) + if ref_latent is not None: + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True) + return (conditioning, ) + + +NODE_CLASS_MAPPINGS = { + "TextEncodeQwenImageEdit": TextEncodeQwenImageEdit, +} diff --git a/nodes.py b/nodes.py index 860a236aa..b3fa9c51a 100644 --- a/nodes.py +++ b/nodes.py @@ -2321,6 +2321,7 @@ async def init_builtin_extra_nodes(): "nodes_edit_model.py", "nodes_tcfg.py", "nodes_context_windows.py", + "nodes_qwen.py", ] import_failed = []