Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-10-25 07:54:30 +00:00)
@@ -365,8 +365,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["patch_size"] = 2
         dit_config["in_channels"] = 16
         dit_config["dim"] = 2304
-        dit_config["cap_feat_dim"] = 2304
-        dit_config["n_layers"] = 26
+        dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1]
+        dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
         dit_config["n_heads"] = 24
         dit_config["n_kv_heads"] = 8
         dit_config["qk_norm"] = True
@@ -890,6 +890,7 @@ class TEModel(Enum):
     QWEN25_3B = 10
     QWEN25_7B = 11
     BYT5_SMALL_GLYPH = 12
+    GEMMA_3_4B = 13
 
 def detect_te_model(sd):
     if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -912,6 +913,8 @@ def detect_te_model(sd):
             return TEModel.BYT5_SMALL_GLYPH
         return TEModel.T5_BASE
     if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
+        if 'model.layers.0.self_attn.q_norm.weight' in sd:
+            return TEModel.GEMMA_3_4B
         return TEModel.GEMMA_2_2B
     if 'model.layers.0.self_attn.k_proj.bias' in sd:
         weight = sd['model.layers.0.self_attn.k_proj.bias']
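The detection here is purely key-based: Gemma 2 and Gemma 3 checkpoints both carry 'model.layers.0.post_feedforward_layernorm.weight', and only Gemma 3 also has per-head 'q_norm' weights, so that one extra key is enough to tell them apart. A rough standalone sketch of the same probing idea (hypothetical helper name, not part of the patch):

    # Illustration of the key-probing idea; "looks_like_gemma" is a made-up name.
    def looks_like_gemma(sd):
        if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
            if 'model.layers.0.self_attn.q_norm.weight' in sd:
                return "gemma3_4b"
            return "gemma2_2b"
        return None

    # Minimal fake state dicts containing only the probed keys:
    print(looks_like_gemma({'model.layers.0.post_feedforward_layernorm.weight': 0}))  # gemma2_2b
    print(looks_like_gemma({'model.layers.0.post_feedforward_layernorm.weight': 0,
                            'model.layers.0.self_attn.q_norm.weight': 0}))            # gemma3_4b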
@@ -1016,6 +1019,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
                clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
                clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
                tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
+            elif te_model == TEModel.GEMMA_3_4B:
+                clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
+                clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
+                tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
            elif te_model == TEModel.LLAMA3_8:
                clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
                clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
@@ -3,6 +3,7 @@ import torch.nn as nn
 from dataclasses import dataclass
 from typing import Optional, Any
 import math
+import logging
 
 from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.model_management
@@ -28,6 +29,9 @@ class Llama2Config:
     mlp_activation = "silu"
     qkv_bias = False
     rope_dims = None
+    q_norm = None
+    k_norm = None
+    rope_scale = None
 
 @dataclass
 class Qwen25_3BConfig:
@@ -46,6 +50,9 @@ class Qwen25_3BConfig:
     mlp_activation = "silu"
     qkv_bias = True
     rope_dims = None
+    q_norm = None
+    k_norm = None
+    rope_scale = None
 
 @dataclass
 class Qwen25_7BVLI_Config:
@@ -64,6 +71,9 @@ class Qwen25_7BVLI_Config:
     mlp_activation = "silu"
     qkv_bias = True
     rope_dims = [16, 24, 24]
+    q_norm = None
+    k_norm = None
+    rope_scale = None
 
 @dataclass
 class Gemma2_2B_Config:
@@ -82,6 +92,32 @@ class Gemma2_2B_Config:
     mlp_activation = "gelu_pytorch_tanh"
     qkv_bias = False
     rope_dims = None
+    q_norm = None
+    k_norm = None
+    sliding_attention = None
+    rope_scale = None
+
+@dataclass
+class Gemma3_4B_Config:
+    vocab_size: int = 262208
+    hidden_size: int = 2560
+    intermediate_size: int = 10240
+    num_hidden_layers: int = 34
+    num_attention_heads: int = 8
+    num_key_value_heads: int = 4
+    max_position_embeddings: int = 131072
+    rms_norm_eps: float = 1e-6
+    rope_theta = [10000.0, 1000000.0]
+    transformer_type: str = "gemma3"
+    head_dim = 256
+    rms_norm_add = True
+    mlp_activation = "gelu_pytorch_tanh"
+    qkv_bias = False
+    rope_dims = None
+    q_norm = "gemma3"
+    k_norm = "gemma3"
+    sliding_attention = [False, False, False, False, False, 1024]
+    rope_scale = [1.0, 8.0]
 
 class RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
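The list-valued fields are what distinguish Gemma 3 here: rope_theta and rope_scale are parallel lists (one RoPE frequency table per entry, with the matching divisor applied to that table's inverse frequencies), and sliding_attention is a repeating per-layer pattern that each transformer block indexes with its layer index modulo the pattern length. A small sketch of how those defaults expand across layers (illustration only, not part of the patch):

    rope_theta = [10000.0, 1000000.0]   # two RoPE tables get precomputed
    rope_scale = [1.0, 8.0]             # divisor for the matching table's inv_freq
    sliding_attention = [False, False, False, False, False, 1024]

    # TransformerBlockGemma2 stores sliding_attention[index % len(pattern)];
    # in forward (gemma3 only), a truthy entry selects freqs_cis[1], False selects freqs_cis[0].
    for layer in range(12):
        flag = sliding_attention[layer % len(sliding_attention)]
        print(layer, flag, "-> freqs_cis[1]" if flag else "-> freqs_cis[0]")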
@@ -106,25 +142,40 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None):
-    theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
-    inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
+def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
+    if not isinstance(theta, list):
+        theta = [theta]
 
-    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-    position_ids_expanded = position_ids[:, None, :].float()
-    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-    emb = torch.cat((freqs, freqs), dim=-1)
-    cos = emb.cos()
-    sin = emb.sin()
-    if rope_dims is not None and position_ids.shape[0] > 1:
-        mrope_section = rope_dims * 2
-        cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
-        sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
-    else:
-        cos = cos.unsqueeze(1)
-        sin = sin.unsqueeze(1)
+    out = []
+    for index, t in enumerate(theta):
+        theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
+        inv_freq = 1.0 / (t ** (theta_numerator / head_dim))
 
-    return (cos, sin)
+        if rope_scale is not None:
+            if isinstance(rope_scale, list):
+                inv_freq /= rope_scale[index]
+            else:
+                inv_freq /= rope_scale
+
+        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+        emb = torch.cat((freqs, freqs), dim=-1)
+        cos = emb.cos()
+        sin = emb.sin()
+        if rope_dims is not None and position_ids.shape[0] > 1:
+            mrope_section = rope_dims * 2
+            cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
+            sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
+        else:
+            cos = cos.unsqueeze(1)
+            sin = sin.unsqueeze(1)
+        out.append((cos, sin))
+
+    if len(out) == 1:
+        return out[0]
+
+    return out
 
 
 def apply_rope(xq, xk, freqs_cis):
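With this change precompute_freqs_cis keeps its old behaviour for a scalar theta (it still returns a single (cos, sin) pair) and returns a list of pairs when theta is a list, optionally dividing each table's inverse frequencies by the matching rope_scale entry. A usage sketch, assuming the function above is in scope:

    import torch

    position_ids = torch.arange(8).unsqueeze(0)  # (batch=1, seq=8)

    single = precompute_freqs_cis(256, position_ids, 10000.0)             # one (cos, sin) pair
    pairs = precompute_freqs_cis(256, position_ids, [10000.0, 1000000.0],
                                 rope_scale=[1.0, 8.0])                   # list of two pairs

    cos, sin = pairs[1]   # the theta=1e6 table, with inv_freq divided by 8.0
    print(cos.shape)      # torch.Size([1, 1, 8, 256])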
@@ -152,6 +203,14 @@ class Attention(nn.Module):
         self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
         self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
 
+        self.q_norm = None
+        self.k_norm = None
+
+        if config.q_norm == "gemma3":
+            self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+        if config.k_norm == "gemma3":
+            self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
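Gemma 3 normalizes the per-head query and key states before RoPE, which is why q_norm and k_norm are built with dim=head_dim: an RMSNorm over the last axis applies the same 256-wide norm to every head of the (batch, heads, seq, head_dim) tensor. A shape-only sketch using PyTorch's built-in RMSNorm (recent PyTorch; the patch's own RMSNorm class additionally takes the add flag used for Gemma's weight offset):

    import torch

    head_dim = 256
    xq = torch.randn(1, 8, 16, head_dim)           # (batch, heads, seq, head_dim)
    q_norm = torch.nn.RMSNorm(head_dim, eps=1e-6)  # normalizes only the last axis
    print(q_norm(xq).shape)                        # torch.Size([1, 8, 16, 256])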
@@ -168,6 +227,11 @@ class Attention(nn.Module):
         xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
         xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
 
+        if self.q_norm is not None:
+            xq = self.q_norm(xq)
+        if self.k_norm is not None:
+            xk = self.k_norm(xk)
+
         xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
 
         xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
@@ -192,7 +256,7 @@ class MLP(nn.Module):
         return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
 
 class TransformerBlock(nn.Module):
-    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
+    def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
         super().__init__()
         self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
         self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
@@ -226,7 +290,7 @@ class TransformerBlock(nn.Module):
         return x
 
 class TransformerBlockGemma2(nn.Module):
-    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
+    def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
         super().__init__()
         self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
         self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
@@ -235,6 +299,13 @@ class TransformerBlockGemma2(nn.Module):
         self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
         self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
 
+        if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
+            self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
+        else:
+            self.sliding_attention = False
+
+        self.transformer_type = config.transformer_type
+
     def forward(
         self,
         x: torch.Tensor,
@@ -242,6 +313,14 @@ class TransformerBlockGemma2(nn.Module):
         freqs_cis: Optional[torch.Tensor] = None,
         optimized_attention=None,
     ):
+        if self.transformer_type == 'gemma3':
+            if self.sliding_attention:
+                if x.shape[1] > self.sliding_attention:
+                    logging.warning("Warning: sliding attention not implemented, results may be incorrect")
+                freqs_cis = freqs_cis[1]
+            else:
+                freqs_cis = freqs_cis[0]
+
         # Self Attention
         residual = x
         x = self.input_layernorm(x)
@@ -276,7 +355,7 @@ class Llama2_(nn.Module):
             device=device,
             dtype=dtype
         )
-        if self.config.transformer_type == "gemma2":
+        if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
             transformer = TransformerBlockGemma2
             self.normalize_in = True
         else:
@@ -284,8 +363,8 @@ class Llama2_(nn.Module):
             self.normalize_in = False
 
         self.layers = nn.ModuleList([
-            transformer(config, device=device, dtype=dtype, ops=ops)
-            for _ in range(config.num_hidden_layers)
+            transformer(config, index=i, device=device, dtype=dtype, ops=ops)
+            for i in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
         # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
@@ -305,6 +384,7 @@ class Llama2_(nn.Module):
         freqs_cis = precompute_freqs_cis(self.config.head_dim,
                                          position_ids,
                                          self.config.rope_theta,
+                                         self.config.rope_scale,
                                          self.config.rope_dims,
                                          device=x.device)
 
@@ -433,3 +513,12 @@ class Gemma2_2B(BaseLlama, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
 
+
+class Gemma3_4B(BaseLlama, torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = Gemma3_4B_Config(**config_dict)
+        self.num_layers = config.num_hidden_layers
+
+        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        self.dtype = dtype
@@ -11,23 +11,41 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer):
     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}
 
+class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        tokenizer = tokenizer_data.get("spiece_model", None)
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
+
+    def state_dict(self):
+        return {"spiece_model": self.tokenizer.serialize_model()}
+
 class LuminaTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer)
 
+class NTokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_4b", tokenizer=Gemma3_4BTokenizer)
+
 class Gemma2_2BModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
 
+class Gemma3_4BModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
 class LuminaModel(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", dtype=None, model_options={}):
-        super().__init__(device=device, dtype=dtype, name="gemma2_2b", clip_model=Gemma2_2BModel, model_options=model_options)
+    def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel):
+        super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
 
 
-def te(dtype_llama=None, llama_scaled_fp8=None):
+def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
+    if model_type == "gemma2_2b":
+        model = Gemma2_2BModel
+    elif model_type == "gemma3_4b":
+        model = Gemma3_4BModel
+
     class LuminaTEModel_(LuminaModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
             if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
@@ -35,5 +53,5 @@ def te(dtype_llama=None, llama_scaled_fp8=None):
                 model_options["scaled_fp8"] = llama_scaled_fp8
             if dtype_llama is not None:
                 dtype = dtype_llama
-            super().__init__(device=device, dtype=dtype, model_options=model_options)
+            super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
     return LuminaTEModel_
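The extended factory just picks which SDClipModel subclass to wrap: model_type="gemma2_2b" keeps the old Gemma2_2BModel path and "gemma3_4b" selects Gemma3_4BModel, with the name passed through to LuminaModel. A usage sketch, assuming a ComfyUI environment where comfy.text_encoders.lumina2 is importable (this mirrors the sd.py wiring above):

    import comfy.text_encoders.lumina2 as lumina2

    clip_class = lumina2.te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma3_4b")
    # clip_class is a LuminaModel subclass that builds Gemma3_4BModel under the name
    # "gemma3_4b"; ComfyUI instantiates it later with the usual device/dtype options.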