diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 9d90d5a61..4c976058f 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -27,6 +27,7 @@ class Llama2Config: rms_norm_add = False mlp_activation = "silu" qkv_bias = False + rope_dims = None @dataclass class Qwen25_3BConfig: @@ -44,6 +45,7 @@ class Qwen25_3BConfig: rms_norm_add = False mlp_activation = "silu" qkv_bias = True + rope_dims = None @dataclass class Qwen25_7BVLI_Config: @@ -61,6 +63,7 @@ class Qwen25_7BVLI_Config: rms_norm_add = False mlp_activation = "silu" qkv_bias = True + rope_dims = [16, 24, 24] @dataclass class Gemma2_2B_Config: @@ -78,6 +81,7 @@ class Gemma2_2B_Config: rms_norm_add = True mlp_activation = "gelu_pytorch_tanh" qkv_bias = False + rope_dims = None class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): @@ -102,7 +106,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def precompute_freqs_cis(head_dim, position_ids, theta, device=None): +def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None): theta_numerator = torch.arange(0, head_dim, 2, device=device).float() inv_freq = 1.0 / (theta ** (theta_numerator / head_dim)) @@ -112,12 +116,20 @@ def precompute_freqs_cis(head_dim, position_ids, theta, device=None): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + if rope_dims is not None and position_ids.shape[0] > 1: + mrope_section = rope_dims * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0) + else: + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) + return (cos, sin) def apply_rope(xq, xk, freqs_cis): - cos = freqs_cis[0].unsqueeze(1) - sin = freqs_cis[1].unsqueeze(1) + cos = freqs_cis[0] + sin = freqs_cis[1] q_embed = (xq * cos) + (rotate_half(xq) * sin) k_embed = (xk * cos) + (rotate_half(xk) * sin) return q_embed, k_embed @@ -292,6 +304,7 @@ class Llama2_(nn.Module): freqs_cis = precompute_freqs_cis(self.config.head_dim, position_ids, self.config.rope_theta, + self.config.rope_dims, device=x.device) mask = None