diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 056e101a4..ad9a7daea 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -261,8 +261,8 @@ class CrossAttention(nn.Module): self.heads = heads self.dim_head = dim_head - self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device) - self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device) + self.q_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device) + self.k_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device) self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device) self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)