Change lumina to native RMSNorm. (#7935)

comfyanonymous 2025-05-04 03:39:23 -07:00 committed by GitHub
parent 9187a09483
commit 80a44b97f5


@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import comfy.ldm.common_dit
-from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
+from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
 from comfy.ldm.modules.attention import optimized_attention_masked
 from comfy.ldm.flux.layers import EmbedND
@@ -64,8 +64,8 @@ class JointAttention(nn.Module):
         )
         if qk_norm:
-            self.q_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
-            self.k_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
+            self.q_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+            self.k_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         else:
             self.q_norm = self.k_norm = nn.Identity()
@@ -242,11 +242,11 @@ class JointTransformerBlock(nn.Module):
             operation_settings=operation_settings,
         )
         self.layer_id = layer_id
-        self.attention_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
-        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
-        self.attention_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
-        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
+        self.attention_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.ffn_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.attention_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.ffn_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.modulation = modulation
         if modulation:
@@ -431,7 +431,7 @@ class NextDiT(nn.Module):
         self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
         self.cap_embedder = nn.Sequential(
-            RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, **operation_settings),
+            operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
             operation_settings.get("operations").Linear(
                 cap_feat_dim,
                 dim,
@@ -457,7 +457,7 @@ class NextDiT(nn.Module):
                 for layer_id in range(n_layers)
             ]
         )
-        self.norm_final = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
+        self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
         assert (dim // n_heads) == sum(axes_dims)
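
For context, the new construction pattern looks the RMSNorm class up on the "operations" entry of operation_settings and passes device and dtype explicitly, instead of importing the mmdit RMSNorm and forwarding **operation_settings. Below is a minimal sketch of that call pattern; the operations namespace is stubbed with torch.nn so it runs without ComfyUI (assumptions: the stub class, the dimensions, and the use of nn.RMSNorm, which requires torch >= 2.4, stand in for the project's actual ops classes):

    # Minimal sketch of the post-commit call pattern, with a stand-in
    # "operations" namespace so the example runs without ComfyUI.
    import torch
    import torch.nn as nn

    class _StubOperations:
        # Stand-ins for the real operations namespace (assumption:
        # it exposes RMSNorm and Linear classes).
        Linear = nn.Linear
        RMSNorm = nn.RMSNorm  # torch >= 2.4

    operation_settings = {
        "operations": _StubOperations,
        "device": torch.device("cpu"),
        "dtype": torch.float32,
    }

    # Same shape as the lines in the diff: resolve the norm class from the
    # operations namespace and give it explicit device/dtype.
    dim = 64
    norm = operation_settings.get("operations").RMSNorm(
        dim,
        eps=1e-5,
        elementwise_affine=True,
        device=operation_settings.get("device"),
        dtype=operation_settings.get("dtype"),
    )

    x = torch.randn(2, 8, dim)
    print(norm(x).shape)  # torch.Size([2, 8, 64])

Routing the norm through the operations namespace keeps weight creation, casting, and device placement under the same machinery as the model's Linear layers, rather than special-casing RMSNorm.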