import math
from typing import Optional, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from comfy.ldm.modules.attention import optimized_attention_for_device


def _smart_resize(height: int, width: int, factor: int, min_pixels: int, max_pixels: int) -> Tuple[int, int]:
    """Return a target (height, width) for resizing, each side a multiple of ``factor``.

    Both sides are first rounded to the nearest multiple of ``factor``. If the
    rounded area exceeds ``max_pixels`` the size is scaled down (flooring to
    multiples of ``factor``, never below ``factor``); if it is below
    ``min_pixels`` it is scaled up (ceiling to multiples of ``factor``).
    The aspect ratio is approximately preserved.
    """
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar


def process_qwen2vl_images(
    images: torch.Tensor,
    min_pixels: int = 3136,
    max_pixels: int = 12845056,
    patch_size: int = 14,
    temporal_patch_size: int = 2,
    merge_size: int = 2,
    image_mean: list = None,
    image_std: list = None,
):
    """Resize, normalize and patchify the first image of a batch for the vision tower.

    Args:
        images: image batch indexed as (batch, height, width, channels); the
            channel axis is moved to the front before resizing. Values are
            presumably floats in [0, 1] (the default mean/std are on that
            scale) -- TODO confirm against callers.
        min_pixels: lower bound on the resized area (see ``_smart_resize``).
        max_pixels: upper bound on the resized area (see ``_smart_resize``).
        patch_size: spatial patch edge length in pixels.
        temporal_patch_size: frames per patch; a still image is repeated this
            many times along the temporal axis.
        merge_size: spatial merge factor of the downstream ``PatchMerger``;
            resized sides are multiples of ``patch_size * merge_size``.
        image_mean: per-channel mean; defaults to the OpenAI CLIP constants.
        image_std: per-channel std; defaults to the OpenAI CLIP constants.

    Returns:
        ``(flatten_patches, image_grid_thw)``:
        * ``flatten_patches``: shape
          ``(grid_h * grid_w, 3 * temporal_patch_size * patch_size**2)``; the
          ``merge_size x merge_size`` patches of each merge group are stored
          in consecutive rows.
        * ``image_grid_thw``: ``(1, 3)`` long tensor ``[[1, grid_h, grid_w]]``
          on the input device.

    Only ``images[0]`` is processed; other batch entries are ignored. Channels
    beyond the first three (e.g. alpha) are dropped; a single-channel image
    broadcasts against the 3-element mean/std and becomes three channels.
    """
    if image_mean is None:
        image_mean = [0.48145466, 0.4578275, 0.40821073]
    if image_std is None:
        image_std = [0.26862954, 0.26130258, 0.27577711]

    _, height, width, _ = images.shape
    device = images.device

    # First image only, RGB only, channels-first for F.interpolate.
    img = images[0, ..., :3].permute(2, 0, 1)

    factor = patch_size * merge_size
    h_bar, w_bar = _smart_resize(height, width, factor, min_pixels, max_pixels)

    img_resized = F.interpolate(
        img.unsqueeze(0), size=(h_bar, w_bar), mode="bilinear", align_corners=False
    ).squeeze(0)

    mean = torch.tensor(image_mean, device=device, dtype=img_resized.dtype).view(-1, 1, 1)
    std = torch.tensor(image_std, device=device, dtype=img_resized.dtype).view(-1, 1, 1)
    pixel_values = (img_resized - mean) / std

    grid_t = 1
    grid_h = h_bar // patch_size
    grid_w = w_bar // patch_size
    image_grid_thw = torch.tensor([[grid_t, grid_h, grid_w]], device=device, dtype=torch.long)

    channel = pixel_values.shape[0]
    # A still image is treated as a clip of `temporal_patch_size` identical frames.
    pixel_values = pixel_values.unsqueeze(0).repeat(temporal_patch_size, 1, 1, 1)
    patches = pixel_values.reshape(
        grid_t,
        temporal_patch_size,
        channel,
        grid_h // merge_size,
        merge_size,
        patch_size,
        grid_w // merge_size,
        merge_size,
        patch_size,
    )
    # -> (t, gh/m, gw/m, m, m, channel, temporal, p, p): merge groups become contiguous rows.
    patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
    flatten_patches = patches.reshape(
        grid_t * grid_h * grid_w,
        channel * temporal_patch_size * patch_size * patch_size,
    )
    return flatten_patches, image_grid_thw


class VisionPatchEmbed(nn.Module):
    """Embed flattened patches with a Conv3d whose kernel equals its stride (one output per patch)."""

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 3584,
        device=None,
        dtype=None,
        ops=None,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        kernel_size = [temporal_patch_size, patch_size, patch_size]
        self.proj = ops.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=False,
            device=device,
            dtype=dtype,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Map ``(num_patches, C*T*P*P)`` rows to ``(num_patches, embed_dim)``."""
        hidden_states = hidden_states.view(
            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        hidden_states = self.proj(hidden_states)
        return hidden_states.view(-1, self.embed_dim)


def rotate_half(x):
    """Split the last dim into halves (x1, x2) and return cat(-x2, x1)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_vision(q, k, cos, sin):
    """Apply rotary position embeddings to ``q`` and ``k``.

    ``cos``/``sin`` get a singleton dim inserted at -2 to broadcast over the
    head axis (q/k are indexed as (seq, heads, head_dim) by the caller). The
    rotation is computed in float32 and the results are cast back to the
    input dtypes, so q/k keep the same dtype as the untouched value tensor.
    """
    orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
    q, k = q.float(), k.float()
    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed.to(orig_q_dtype), k_embed.to(orig_k_dtype)


class VisionRotaryEmbedding(nn.Module):
    """Produce the rotary angle table ``outer(position, inv_freq)`` for ``dim`` rotary channels."""

    def __init__(self, dim: int, theta: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.theta = theta

    def forward(self, seqlen: int, device) -> torch.Tensor:
        """Return a float32 ``(seqlen, dim // 2)`` table of angles for positions ``0..seqlen-1``."""
        inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=device) / self.dim))
        seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
        freqs = torch.outer(seq, inv_freq)
        return freqs


class PatchMerger(nn.Module):
    """RMS-normalize tokens, concatenate each group of ``spatial_merge_size**2`` consecutive tokens, and project to ``dim``."""

    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2, device=None, dtype=None, ops=None):
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size ** 2)
        self.ln_q = ops.RMSNorm(context_dim, eps=1e-6, device=device, dtype=dtype)
        self.mlp = nn.Sequential(
            ops.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype),
            nn.GELU(),
            ops.Linear(self.hidden_size, dim, device=device, dtype=dtype),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The reshape groups consecutive rows, so callers must keep each merge group contiguous.
        x = self.ln_q(x).reshape(-1, self.hidden_size)
        x = self.mlp(x)
        return x


class VisionAttention(nn.Module):
    """Multi-head self-attention computed independently inside each ``cu_seqlens`` segment."""

    def __init__(self, hidden_size: int, num_heads: int, device=None, dtype=None, ops=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Not used in forward; scaling is left to `optimized_attention`. Kept for compatibility.
        self.scaling = self.head_dim ** -0.5

        self.qkv = ops.Linear(hidden_size, hidden_size * 3, bias=True, device=device, dtype=dtype)
        self.proj = ops.Linear(hidden_size, hidden_size, bias=True, device=device, dtype=dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        cu_seqlens=None,
        optimized_attention=None,
    ) -> torch.Tensor:
        """Attend within segments and return a ``(seq, hidden_size)`` tensor.

        Args:
            hidden_states: ``(seq, hidden_size)``, or ``(1, seq, hidden_size)``.
            position_embeddings: optional ``(cos, sin)`` pair, each ``(seq, head_dim)``.
            cu_seqlens: 1-D cumulative segment boundaries starting at 0; tokens
                only attend within their segment. ``None`` means one segment
                spanning the whole sequence.
            optimized_attention: attention callable; defaults to the one chosen
                by ``optimized_attention_for_device`` for this device.

        Raises:
            ValueError: if a 3-D input has a batch size other than 1.
        """
        if hidden_states.dim() == 3:
            if hidden_states.shape[0] != 1:
                raise ValueError("VisionAttention only supports a batch size of 1")
            hidden_states = hidden_states[0]
        seq_length = hidden_states.shape[0]

        if optimized_attention is None:
            optimized_attention = optimized_attention_for_device(hidden_states.device, mask=False, small_input=True)

        qkv = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, self.head_dim)
        # Each of q/k/v: (seq, heads, head_dim).
        query_states, key_states, value_states = qkv.permute(1, 0, 2, 3).unbind(0)

        if position_embeddings is not None:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

        # (seq, heads, head_dim) -> (1, heads, seq, head_dim); presumably the layout
        # `optimized_attention` expects with skip_reshape=True -- TODO confirm.
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)

        if cu_seqlens is None:
            lengths = [seq_length]
        else:
            lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

        splits = [
            torch.split(tensor, lengths, dim=2)
            for tensor in (query_states, key_states, value_states)
        ]
        attn_outputs = [
            optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
            for q, k, v in zip(*splits)
        ]
        # Segment outputs are concatenated back along the sequence axis (dim 1).
        attn_output = torch.cat(attn_outputs, dim=1).reshape(seq_length, -1)
        return self.proj(attn_output)


class VisionMLP(nn.Module):
    """Gated MLP: ``down_proj(SiLU(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, hidden_size: int, intermediate_size: int, device=None, dtype=None, ops=None):
        super().__init__()
        self.gate_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
        self.up_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
        self.down_proj = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype)
        self.act_fn = nn.SiLU()

    def forward(self, hidden_state):
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))


class VisionBlock(nn.Module):
    """Pre-norm transformer block: ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``."""

    def __init__(self, hidden_size: int, intermediate_size: int, num_heads: int, device=None, dtype=None, ops=None):
        super().__init__()
        self.norm1 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
        self.norm2 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
        self.attn = VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops)
        self.mlp = VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        cu_seqlens=None,
        optimized_attention=None,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.attn(hidden_states, position_embeddings, cu_seqlens, optimized_attention)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class Qwen2VLVisionTransformer(nn.Module):
    """Vision tower: patch embedding, windowed/full attention blocks, and a spatial patch merger.

    Blocks whose index is in ``fullatt_block_indexes`` attend over each whole
    image; the others attend within windows of ``window_size`` pixels (measured
    before patching). Output is one token per ``spatial_merge_size**2`` patches,
    in row-major order over each image's merged grid.
    """

    def __init__(
        self,
        hidden_size: int = 3584,
        output_hidden_size: int = 3584,
        intermediate_size: int = 3420,
        num_heads: int = 16,
        num_layers: int = 32,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        spatial_merge_size: int = 2,
        window_size: int = 112,
        device=None,
        dtype=None,
        ops=None,
        fullatt_block_indexes: Optional[Sequence[int]] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.window_size = window_size
        # Blocks that use full (per-image) attention instead of windowed attention.
        self.fullatt_block_indexes = (
            [7, 15, 23, 31] if fullatt_block_indexes is None else list(fullatt_block_indexes)
        )

        self.patch_embed = VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=3,
            embed_dim=hidden_size,
            device=device,
            dtype=dtype,
            ops=ops,
        )

        head_dim = hidden_size // num_heads
        # Half of each head's channels encode the row position and half the column position.
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([
            VisionBlock(hidden_size, intermediate_size, num_heads, device, dtype, ops)
            for _ in range(num_layers)
        ])

        self.merger = PatchMerger(
            dim=output_hidden_size,
            context_dim=hidden_size,
            spatial_merge_size=spatial_merge_size,
            device=device,
            dtype=dtype,
            ops=ops,
        )

    def get_window_index(self, grid_thw):
        """Compute the window-major token permutation and window boundaries.

        Windows are square tiles of ``window_size // spatial_merge_size // patch_size``
        merged tokens per side on each image's merged grid. Merged-token indices
        are padded with -100 to a whole number of windows, regrouped window by
        window, and the padding is dropped.

        Returns:
            window_index: 1-D tensor of merged-token indices (offset per image)
                in window-major order; a permutation of all merged tokens.
            cu_window_seqlens: list of cumulative window lengths in patch units
                (merged counts times ``spatial_merge_size**2``), starting at 0.
                Empty windows (padding always adds at least one partial window
                per side) produce repeated entries.
        """
        window_index = []
        cu_window_seqlens = [0]
        window_index_id = 0
        vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h = grid_h // self.spatial_merge_size
            llm_grid_w = grid_w // self.spatial_merge_size

            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)

            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size

            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )

            # Real (non-padding) tokens per window.
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)

            cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_size * self.spatial_merge_size + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += int(grid_t * llm_grid_h * llm_grid_w)

        window_index = torch.cat(window_index, dim=0)
        return window_index, cu_window_seqlens

    def get_position_embeddings(self, grid_thw, device):
        """Return ``(num_patches, head_dim // 2)`` rotary angles from each patch's (row, col).

        Positions are listed in the same merge-group-major order produced by
        ``process_qwen2vl_images`` and repeated for every temporal step.
        """
        pos_ids = []

        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h, device=device).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten()

            wpos_ids = torch.arange(w, device=device).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3).flatten()

            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))

        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device)
        # (N, 2, dim/2) -> (N, dim): row angles followed by column angles.
        return rotary_pos_emb_full[pos_ids].flatten(1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        image_grid_thw: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode flattened patches into merged tokens.

        Args:
            pixel_values: patch rows as returned by ``process_qwen2vl_images``.
            image_grid_thw: ``(num_images, 3)`` long tensor of (t, h, w) patch grids.

        Returns:
            ``(num_patches // spatial_merge_size**2, output_hidden_size)`` tensor,
            in row-major order over each image's merged grid.

        Raises:
            ValueError: if ``image_grid_thw`` is None.
        """
        if image_grid_thw is None:
            raise ValueError("image_grid_thw is required")

        optimized_attention = optimized_attention_for_device(pixel_values.device, mask=False, small_input=True)

        hidden_states = self.patch_embed(pixel_values)

        window_index, cu_window_seqlens = self.get_window_index(image_grid_thw)
        cu_window_seqlens = torch.tensor(cu_window_seqlens, device=hidden_states.device)
        # Drop the zero-length segments produced by empty padding windows.
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        position_embeddings = self.get_position_embeddings(image_grid_thw, hidden_states.device)

        seq_len, _ = hidden_states.size()
        spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

        # Reorder whole merge groups into window-major order so each window is contiguous.
        hidden_states = hidden_states.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)

        position_embeddings = position_embeddings.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1)
        position_embeddings = position_embeddings[window_index, :, :]
        position_embeddings = position_embeddings.reshape(seq_len, -1)
        position_embeddings = torch.cat((position_embeddings, position_embeddings), dim=-1)
        position_embeddings = (position_embeddings.cos(), position_embeddings.sin())

        # Per-frame boundaries used by the full-attention blocks.
        cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum(
            dim=0,
            dtype=torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for i, block in enumerate(self.blocks):
            if i in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention)

        hidden_states = self.merger(hidden_states)

        # Undo the window-major reordering so merged tokens come out in
        # row-major order over each image's merged grid.
        reverse_indices = torch.argsort(window_index).to(hidden_states.device)
        hidden_states = hidden_states[reverse_indices, :]
        return hidden_states