Add explicit casting in apply_rope for Qwen VL (#9759)

This commit is contained in:
contentis
2025-09-08 21:08:18 +02:00
committed by GitHub
parent bd1d9bcd5f
commit 97652d26b8

View File

@@ -128,11 +128,12 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=N
 def apply_rope(xq, xk, freqs_cis):
+    org_dtype = xq.dtype
     cos = freqs_cis[0]
     sin = freqs_cis[1]
     q_embed = (xq * cos) + (rotate_half(xq) * sin)
     k_embed = (xk * cos) + (rotate_half(xk) * sin)
-    return q_embed, k_embed
+    return q_embed.to(org_dtype), k_embed.to(org_dtype)
 class Attention(nn.Module):