This commit is contained in:
comfyanonymous
2025-04-15 12:13:28 -04:00
parent 3e8155f7a3
commit 6fc5dbd52a
2 changed files with 0 additions and 13 deletions

View File

@@ -27,17 +27,6 @@ def rms_norm(x, weight=None, eps=1e-6):
if RMSNorm is None:
class RMSNorm(torch.nn.Module):
def __init__(
self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None, **kwargs
):
super().__init__()
self.eps = eps
self.learnable_scale = elementwise_affine
if self.learnable_scale:
self.weight = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
else:
self.register_parameter("weight", None)
def __init__(
self,
normalized_shape,