Rope fix for qwen vl. (#9435)

Author: comfyanonymous
Date: 2025-08-19 17:47:42 -07:00
Committed by: GitHub
parent bddd69618b
commit dfa791eb4b


@@ -27,6 +27,7 @@ class Llama2Config:
     rms_norm_add = False
     mlp_activation = "silu"
     qkv_bias = False
+    rope_dims = None
 
 @dataclass
 class Qwen25_3BConfig:
@@ -44,6 +45,7 @@ class Qwen25_3BConfig:
     rms_norm_add = False
     mlp_activation = "silu"
     qkv_bias = True
+    rope_dims = None
 
 @dataclass
 class Qwen25_7BVLI_Config:
@@ -61,6 +63,7 @@ class Qwen25_7BVLI_Config:
     rms_norm_add = False
     mlp_activation = "silu"
     qkv_bias = True
+    rope_dims = [16, 24, 24]
 
 @dataclass
 class Gemma2_2B_Config:
@@ -78,6 +81,7 @@ class Gemma2_2B_Config:
     rms_norm_add = True
     mlp_activation = "gelu_pytorch_tanh"
     qkv_bias = False
+    rope_dims = None
 
 class RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -102,7 +106,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
-def precompute_freqs_cis(head_dim, position_ids, theta, device=None):
+def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None):
     theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
     inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
@@ -112,12 +116,20 @@ def precompute_freqs_cis(head_dim, position_ids, theta, device=None):
     emb = torch.cat((freqs, freqs), dim=-1)
     cos = emb.cos()
     sin = emb.sin()
+    if rope_dims is not None and position_ids.shape[0] > 1:
+        mrope_section = rope_dims * 2
+        cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
+        sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
+    else:
+        cos = cos.unsqueeze(1)
+        sin = sin.unsqueeze(1)
     return (cos, sin)
 
 def apply_rope(xq, xk, freqs_cis):
-    cos = freqs_cis[0].unsqueeze(1)
-    sin = freqs_cis[1].unsqueeze(1)
+    cos = freqs_cis[0]
+    sin = freqs_cis[1]
     q_embed = (xq * cos) + (rotate_half(xq) * sin)
     k_embed = (xk * cos) + (rotate_half(xk) * sin)
     return q_embed, k_embed
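For reference, the mRoPE split in the hunk above can be exercised in isolation. The following is a minimal sketch, not part of the commit; the name mrope_interleave and the (3, seq_len, head_dim) table shape are illustrative assumptions. Note that rope_dims is a list, so rope_dims * 2 is Python list repetition, not scalar doubling: with rope_dims = [16, 24, 24] (summing to head_dim // 2), the table splits into six chunks, and chunk i is taken from position axis i % 3 (temporal, height, width).

import torch

def mrope_interleave(table, rope_dims):
    # table: (3, seq_len, head_dim) -- one rotary table (cos or sin) per
    # position axis. The table stores cat((freqs, freqs), dim=-1), which is
    # why each section length appears twice in the split sizes.
    mrope_section = rope_dims * 2  # [16, 24, 24] -> [16, 24, 24, 16, 24, 24]
    chunks = table.split(mrope_section, dim=-1)
    # Output section i comes from position axis i % 3.
    return torch.cat([m[i % 3] for i, m in enumerate(chunks)], dim=-1).unsqueeze(0)

head_dim, seq_len = 128, 10
rope_dims = [16, 24, 24]  # as in Qwen25_7BVLI_Config above
table = torch.randn(3, seq_len, head_dim)  # stand-in for cos or sin
print(mrope_interleave(table, rope_dims).shape)  # torch.Size([1, 10, 128])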
@@ -292,6 +304,7 @@ class Llama2_(nn.Module):
         freqs_cis = precompute_freqs_cis(self.config.head_dim,
                                          position_ids,
                                          self.config.rope_theta,
+                                         self.config.rope_dims,
                                          device=x.device)
 
         mask = None
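A closing note on the apply_rope change: the unsqueeze moved out of apply_rope because the two branches of precompute_freqs_cis now need different broadcast shapes against q/k of shape (batch, n_heads, seq_len, head_dim). A hedged sketch of the shape arithmetic, with shapes inferred from the diff rather than taken from ComfyUI code:

import torch

batch, n_heads, seq_len, head_dim = 2, 4, 10, 128
xq = torch.randn(batch, n_heads, seq_len, head_dim)

# else branch: cos is (batch, seq_len, head_dim); unsqueeze(1) broadcasts
# the same table over every attention head.
plain_cos = torch.randn(batch, seq_len, head_dim).unsqueeze(1)
# mRoPE branch: cos is (seq_len, head_dim) after the interleave; unsqueeze(0)
# broadcasts it over both batch and heads.
mrope_cos = torch.randn(seq_len, head_dim).unsqueeze(0)

print((xq * plain_cos).shape)  # torch.Size([2, 4, 10, 128])
print((xq * mrope_cos).shape)  # torch.Size([2, 4, 10, 128])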