mirror of https://github.com/comfyanonymous/ComfyUI.git
Change lumina to native RMSNorm. (#7935)

parent 9187a09483
commit 80a44b97f5
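In short: the diff drops the RMSNorm import from comfy.ldm.modules.diffusionmodules.mmdit and instead constructs every norm through the operations object carried in operation_settings, passing the target device and dtype explicitly rather than forwarding **operation_settings. The call sites touched are the q/k norms in JointAttention, the four per-block norms in JointTransformerBlock, and the caption embedder plus final norm in NextDiT.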
@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import comfy.ldm.common_dit
 
-from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
+from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
 from comfy.ldm.modules.attention import optimized_attention_masked
 from comfy.ldm.flux.layers import EmbedND
 
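For context, the helper being dropped implements RMS normalization: each token is scaled by the reciprocal of its root-mean-square over the last dimension, then multiplied by a learned per-channel gain. A minimal sketch of such a module, written to match the call signature used in this diff (an illustration only, not ComfyUI's exact mmdit code):

    import torch
    import torch.nn as nn

    class RMSNorm(nn.Module):
        # Scale each token by the reciprocal of its root-mean-square over the
        # last dimension, then apply a learned per-channel gain.
        def __init__(self, dim, eps=1e-5, elementwise_affine=True, device=None, dtype=None):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype)) if elementwise_affine else None

        def forward(self, x):
            x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
            return x * self.weight if self.weight is not None else x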
@@ -64,8 +64,8 @@ class JointAttention(nn.Module):
         )
 
         if qk_norm:
-            self.q_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
-            self.k_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
+            self.q_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+            self.k_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         else:
             self.q_norm = self.k_norm = nn.Identity()
 
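The rewritten call sites assume operation_settings is a dict carrying an "operations" namespace plus a target device and dtype. A hypothetical sketch of how the new q_norm line resolves; NativeOps is invented here for illustration (ComfyUI's real operations classes live in comfy.ops), and torch.nn.RMSNorm requires PyTorch 2.4 or newer:

    import torch

    class NativeOps:
        # Stand-in operations namespace; exposes the layer classes a model
        # constructs through operation_settings.get("operations").
        RMSNorm = torch.nn.RMSNorm

    operation_settings = {
        "operations": NativeOps,
        "device": torch.device("cpu"),
        "dtype": torch.float32,
    }

    head_dim = 64
    q_norm = operation_settings.get("operations").RMSNorm(
        head_dim,
        elementwise_affine=True,
        device=operation_settings.get("device"),
        dtype=operation_settings.get("dtype"),
    )
    print(q_norm(torch.randn(2, 8, head_dim)).shape)  # torch.Size([2, 8, 64])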
@@ -242,11 +242,11 @@ class JointTransformerBlock(nn.Module):
             operation_settings=operation_settings,
         )
         self.layer_id = layer_id
-        self.attention_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
-        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
+        self.attention_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.ffn_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 
-        self.attention_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
-        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
+        self.attention_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.ffn_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 
         self.modulation = modulation
         if modulation:
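If "native RMSNorm" in the commit title refers to PyTorch's built-in torch.nn.RMSNorm (an assumption; the diff itself only shows the operations indirection), a quick check that the built-in computes the same textbook formula as the helper it replaces, with weights at their ones initialization:

    import torch

    # Compare torch.nn.RMSNorm against x * rsqrt(mean(x**2) + eps); the
    # learned gain is ones at init, so the two should agree numerically.
    dim, eps = 32, 1e-5
    x = torch.randn(4, dim)
    native = torch.nn.RMSNorm(dim, eps=eps)
    manual = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    print(torch.allclose(native(x), manual, atol=1e-6))  # True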
@@ -431,7 +431,7 @@ class NextDiT(nn.Module):
 
         self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
         self.cap_embedder = nn.Sequential(
-            RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, **operation_settings),
+            operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
             operation_settings.get("operations").Linear(
                 cap_feat_dim,
                 dim,
@@ -457,7 +457,7 @@ class NextDiT(nn.Module):
                 for layer_id in range(n_layers)
             ]
         )
-        self.norm_final = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
+        self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
 
         assert (dim // n_heads) == sum(axes_dims)