import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
import math

from comfy.ldm.modules.attention import optimized_attention_for_device
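
# Image preprocessing for the Qwen2-VL style vision encoder: resize the input so
# both sides are multiples of patch_size * merge_size, keep the total pixel count
# inside [min_pixels, max_pixels], normalize with the CLIP image mean/std, and
# flatten the result into per-patch vectors plus a (t, h, w) patch grid.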
def process_qwen2vl_images(
    images: torch.Tensor,
    min_pixels: int = 3136,
    max_pixels: int = 12845056,
    patch_size: int = 14,
    temporal_patch_size: int = 2,
    merge_size: int = 2,
    image_mean: list = None,
    image_std: list = None,
):
    if image_mean is None:
        image_mean = [0.48145466, 0.4578275, 0.40821073]
    if image_std is None:
        image_std = [0.26862954, 0.26130258, 0.27577711]
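
    # ComfyUI image tensors are (batch, height, width, channels); only the first
    # image in the batch is converted to patches below.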
    batch_size, height, width, channels = images.shape
    device = images.device
    # dtype = images.dtype

    images = images.permute(0, 3, 1, 2)

    grid_thw_list = []
    img = images[0]
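
    # "Smart resize": round each side to the nearest multiple of
    # factor = patch_size * merge_size (28 by default), then rescale if the
    # resulting pixel count falls outside [min_pixels, max_pixels].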
    factor = patch_size * merge_size

    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor

    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
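
    # Bilinear resize to the target resolution, then per-channel normalization
    # with the CLIP image mean/std.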
    img_resized = F.interpolate(
        img.unsqueeze(0),
        size=(h_bar, w_bar),
        mode='bilinear',
        align_corners=False
    ).squeeze(0)

    normalized = img_resized.clone()
    for c in range(3):
        normalized[c] = (img_resized[c] - image_mean[c]) / image_std[c]

    grid_h = h_bar // patch_size
    grid_w = w_bar // patch_size
    grid_thw = torch.tensor([1, grid_h, grid_w], device=device, dtype=torch.long)

    pixel_values = normalized
    grid_thw_list.append(grid_thw)
    image_grid_thw = torch.stack(grid_thw_list)
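
    # The single image is repeated along a new frame axis so it fills
    # temporal_patch_size (2) frames, then rearranged so every merge_size x
    # merge_size group of patches is contiguous and each row is one flattened
    # patch of size channel * temporal_patch_size * patch_size * patch_size.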
    grid_t = 1
    channel = pixel_values.shape[0]
    pixel_values = pixel_values.unsqueeze(0).repeat(2, 1, 1, 1)

    patches = pixel_values.reshape(
        grid_t,
        temporal_patch_size,
        channel,
        grid_h // merge_size,
        merge_size,
        patch_size,
        grid_w // merge_size,
        merge_size,
        patch_size,
    )

    patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
    flatten_patches = patches.reshape(
        grid_t * grid_h * grid_w,
        channel * temporal_patch_size * patch_size * patch_size
    )

    return flatten_patches, image_grid_thw
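
# Patch embedding: a Conv3d whose kernel and stride both equal
# (temporal_patch_size, patch_size, patch_size), so each flattened patch is
# projected independently to an embed_dim vector.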
class VisionPatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 3584,
        device=None,
        dtype=None,
        ops=None,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        kernel_size = [temporal_patch_size, patch_size, patch_size]
        self.proj = ops.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=False,
            device=device,
            dtype=dtype
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = hidden_states.view(
            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        hidden_states = self.proj(hidden_states)
        return hidden_states.view(-1, self.embed_dim)
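
# Rotary position embedding helpers for the vision tower: rotate_half and
# apply_rotary_pos_emb_vision apply the standard RoPE rotation to the query and
# key tensors, and VisionRotaryEmbedding builds the frequency table.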
def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_vision(q, k, cos, sin):
    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class VisionRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.theta = theta

    def forward(self, seqlen: int, device) -> torch.Tensor:
        inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=device) / self.dim))
        seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
        freqs = torch.outer(seq, inv_freq)
        return freqs
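
# PatchMerger: RMSNorm followed by a two-layer MLP that merges every
# spatial_merge_size ** 2 neighbouring patch tokens into a single output token
# of dimension `dim`.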
class PatchMerger(nn.Module):
    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2, device=None, dtype=None, ops=None):
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size ** 2)
        self.ln_q = ops.RMSNorm(context_dim, eps=1e-6, device=device, dtype=dtype)
        self.mlp = nn.Sequential(
            ops.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype),
            nn.GELU(),
            ops.Linear(self.hidden_size, dim, device=device, dtype=dtype),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.ln_q(x).reshape(-1, self.hidden_size)
        x = self.mlp(x)
        return x
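
# Self-attention with a fused QKV projection and rotary embeddings applied to
# q/k. cu_seqlens holds cumulative window boundaries; the sequence is split at
# those boundaries and attention runs independently inside each window.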
class VisionAttention(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, device=None, dtype=None, ops=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scaling = self.head_dim ** -0.5

        self.qkv = ops.Linear(hidden_size, hidden_size * 3, bias=True, device=device, dtype=dtype)
        self.proj = ops.Linear(hidden_size, hidden_size, bias=True, device=device, dtype=dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        cu_seqlens=None,
        optimized_attention=None,
    ) -> torch.Tensor:
        if hidden_states.dim() == 2:
            seq_length, _ = hidden_states.shape
            batch_size = 1
            hidden_states = hidden_states.unsqueeze(0)
        else:
            batch_size, seq_length, _ = hidden_states.shape

        qkv = self.qkv(hidden_states)
        qkv = qkv.reshape(batch_size, seq_length, 3, self.num_heads, self.head_dim)
        query_states, key_states, value_states = qkv.reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)

        if position_embeddings is not None:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)

        lengths = cu_seqlens[1:] - cu_seqlens[:-1]
        splits = [
            torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
        ]

        attn_outputs = [
            optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
            for q, k, v in zip(*splits)
        ]
        attn_output = torch.cat(attn_outputs, dim=1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)

        return attn_output
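
# SwiGLU-style feed-forward network: gate and up projections, SiLU on the gate,
# then a down projection back to hidden_size.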
class VisionMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int, device=None, dtype=None, ops=None):
        super().__init__()
        self.gate_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
        self.up_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
        self.down_proj = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype)
        self.act_fn = nn.SiLU()

    def forward(self, hidden_state):
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
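
# Pre-norm transformer block: RMSNorm -> attention -> residual, then
# RMSNorm -> MLP -> residual.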
class VisionBlock(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int, num_heads: int, device=None, dtype=None, ops=None):
        super().__init__()
        self.norm1 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
        self.norm2 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
        self.attn = VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops)
        self.mlp = VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        cu_seqlens=None,
        optimized_attention=None,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.attn(hidden_states, position_embeddings, cu_seqlens, optimized_attention)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
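
# Full vision transformer: patch embedding, windowed self-attention blocks
# (full attention only at fullatt_block_indexes), and a PatchMerger that
# produces the final image tokens.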
class Qwen2VLVisionTransformer(nn.Module):
    def __init__(
        self,
        hidden_size: int = 3584,
        output_hidden_size: int = 3584,
        intermediate_size: int = 3420,
        num_heads: int = 16,
        num_layers: int = 32,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        spatial_merge_size: int = 2,
        window_size: int = 112,
        device=None,
        dtype=None,
        ops=None
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.window_size = window_size
        self.fullatt_block_indexes = [7, 15, 23, 31]

        self.patch_embed = VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=3,
            embed_dim=hidden_size,
            device=device,
            dtype=dtype,
            ops=ops,
        )

        head_dim = hidden_size // num_heads
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([
            VisionBlock(hidden_size, intermediate_size, num_heads, device, dtype, ops)
            for _ in range(num_layers)
        ])

        self.merger = PatchMerger(
            dim=output_hidden_size,
            context_dim=hidden_size,
            spatial_merge_size=spatial_merge_size,
            device=device,
            dtype=dtype,
            ops=ops,
        )
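
    # Builds the permutation that reorders merged tokens into
    # window_size x window_size pixel windows, plus the cumulative patch counts
    # (cu_window_seqlens) marking each window boundary.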
    def get_window_index(self, grid_thw):
        window_index = []
        cu_window_seqlens = [0]
        window_index_id = 0
        vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h = grid_h // self.spatial_merge_size
            llm_grid_w = grid_w // self.spatial_merge_size

            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)

            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size

            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )

            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)

            cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_size * self.spatial_merge_size + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()

        window_index = torch.cat(window_index, dim=0)
        return window_index, cu_window_seqlens
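
    # Per-patch 2D (height, width) rotary position ids, laid out in the same
    # merge-block order as the patches, then looked up in the rotary frequency
    # table and flattened to one embedding per patch.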
    def get_position_embeddings(self, grid_thw, device):
        pos_ids = []

        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h, device=device).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten()

            wpos_ids = torch.arange(w, device=device).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3).flatten()

            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))

        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device)
        return rotary_pos_emb_full[pos_ids].flatten(1)
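
    # Embed patches, reorder tokens and position embeddings into window order,
    # run the blocks with window attention (full attention at
    # fullatt_block_indexes), and merge patch tokens with the PatchMerger.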
    def forward(
        self,
        pixel_values: torch.Tensor,
        image_grid_thw: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        optimized_attention = optimized_attention_for_device(pixel_values.device, mask=False, small_input=True)

        hidden_states = self.patch_embed(pixel_values)

        window_index, cu_window_seqlens = self.get_window_index(image_grid_thw)
        cu_window_seqlens = torch.tensor(cu_window_seqlens, device=hidden_states.device)
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        position_embeddings = self.get_position_embeddings(image_grid_thw, hidden_states.device)

        seq_len, _ = hidden_states.size()
        spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

        hidden_states = hidden_states.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)

        position_embeddings = position_embeddings.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1)
        position_embeddings = position_embeddings[window_index, :, :]
        position_embeddings = position_embeddings.reshape(seq_len, -1)
        position_embeddings = torch.cat((position_embeddings, position_embeddings), dim=-1)
        position_embeddings = (position_embeddings.cos(), position_embeddings.sin())

        cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum(
            dim=0,
            dtype=torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for i, block in enumerate(self.blocks):
            if i in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention)

        hidden_states = self.merger(hidden_states)
        return hidden_states