From fa62287f1f47f6ed30de077d5623bf07b805f7a0 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 26 Feb 2025 05:22:29 -0500
Subject: [PATCH] More code reuse in wan. Fix bug when changing the compute
 dtype on wan.

---
 comfy/ldm/wan/model.py | 27 +++++----------------------
 comfy/model_patcher.py |  2 +-
 2 files changed, 6 insertions(+), 23 deletions(-)

diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index 6533039f7..e88a1834b 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -9,9 +9,11 @@ from einops import repeat
 from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.layers import EmbedND
 from comfy.ldm.flux.math import apply_rope
+from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
 import comfy.ldm.common_dit
 import comfy.model_management
 
+
 def sinusoidal_embedding_1d(dim, position):
     # preprocess
     assert dim % 2 == 0
@@ -25,25 +27,6 @@ def sinusoidal_embedding_1d(dim, position):
     return x
 
 
-class WanRMSNorm(nn.Module):
-
-    def __init__(self, dim, eps=1e-5, device=None, dtype=None):
-        super().__init__()
-        self.dim = dim
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))
-
-    def forward(self, x):
-        r"""
-        Args:
-            x(Tensor): Shape [B, L, C]
-        """
-        return self._norm(x.float()).type_as(x) * comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device)
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
-
-
 class WanSelfAttention(nn.Module):
 
     def __init__(self,
@@ -66,8 +49,8 @@ class WanSelfAttention(nn.Module):
         self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.norm_q = WanRMSNorm(dim, eps=eps, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
-        self.norm_k = WanRMSNorm(dim, eps=eps, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
+        self.norm_q = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
+        self.norm_k = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
 
     def forward(self, x, freqs):
         r"""
@@ -131,7 +114,7 @@ class WanI2VCrossAttention(WanSelfAttention):
         self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         # self.alpha = nn.Parameter(torch.zeros((1, )))
-        self.norm_k_img = WanRMSNorm(dim, eps=eps, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
+        self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
 
     def forward(self, x, context):
         r"""
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 4dbe1b7aa..8a1f8fb63 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -639,7 +639,7 @@ class ModelPatcher:
                 mem_counter += module_mem
                 load_completely.append((module_mem, n, m, params))
 
-            if cast_weight:
+            if cast_weight and hasattr(m, "comfy_cast_weights"):
                 m.prev_comfy_cast_weights = m.comfy_cast_weights
                 m.comfy_cast_weights = True
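
Notes (not part of the patch):

For reference, the normalization that the deleted WanRMSNorm implemented, and
that the shared RMSNorm from comfy/ldm/modules/diffusionmodules/mmdit is
expected to reproduce when constructed with elementwise_affine=True, is
sketched below as a standalone function. The helper name rms_norm is
illustrative, not ComfyUI API; the math is taken from the removed
_norm/forward pair above.

    import torch

    def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        # Root-mean-square normalization over the last (channel) dimension,
        # computed in float32 for numerical stability, then cast back to the
        # input dtype and scaled by the learnable per-channel weight.
        normed = x.float() * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
        return normed.type_as(x) * weight.to(dtype=x.dtype, device=x.device)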
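The hasattr guard in model_patcher.py appears to be what the subject's "fix
bug when changing the compute dtype" refers to: when cast_weight is forced,
the loader previously read m.comfy_cast_weights on every module, but modules
that are not ComfyUI cast-weight ops (presumably including the shared RMSNorm
swapped in above) never define that attribute. A minimal, hypothetical repro
of the failure mode the guard avoids; nn.SiLU is only a stand-in for such a
module, not the loader's actual input:

    import torch.nn as nn

    m = nn.SiLU()      # stand-in for a module without comfy_cast_weights
    cast_weight = True

    # Without the hasattr check, reading m.comfy_cast_weights raises
    # AttributeError; with it, modules lacking the attribute are skipped.
    if cast_weight and hasattr(m, "comfy_cast_weights"):
        m.prev_comfy_cast_weights = m.comfy_cast_weights
        m.comfy_cast_weights = True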