Mirror of https://github.com/comfyanonymous/ComfyUI.git
More code reuse in wan.
Fix bug when changing the compute dtype on wan.
commit fa62287f1f (parent 0844998db3)
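In short: the first four hunks drop wan's private WanRMSNorm and reuse the RMSNorm already provided by comfy.ldm.modules.diffusionmodules.mmdit (with elementwise_affine=True so the learnable scale is kept), and the last hunk makes ModelPatcher flip the weight-cast flag only on modules that define it. The sketch below is a hypothetical stand-in, not the shipped class; it only illustrates why a norm written this way tolerates a compute-dtype change: the affine weight is cast to the input's dtype and device at call time, the same comfy.model_management.cast_to pattern the removed WanRMSNorm used.

import torch
import torch.nn as nn

import comfy.model_management


class RMSNormSketch(nn.Module):
    # Illustrative stand-in only (assumption: the shared RMSNorm behaves roughly like this).
    def __init__(self, dim, eps=1e-6, elementwise_affine=True, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))
        else:
            self.register_parameter("weight", None)

    def forward(self, x):
        # Normalize in float32 for stability, then return to the input's dtype.
        norm = (x.float() * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)).type_as(x)
        if self.weight is None:
            return norm
        # cast_to keeps the weight on x's dtype/device, so changing the compute
        # dtype between calls cannot desync the weight from the activations.
        return norm * comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device)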
@@ -9,9 +9,11 @@ from einops import repeat
 from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.layers import EmbedND
 from comfy.ldm.flux.math import apply_rope
+from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
 import comfy.ldm.common_dit
 import comfy.model_management
 
+
 def sinusoidal_embedding_1d(dim, position):
     # preprocess
     assert dim % 2 == 0
@@ -25,25 +27,6 @@ def sinusoidal_embedding_1d(dim, position):
     return x
 
 
-class WanRMSNorm(nn.Module):
-
-    def __init__(self, dim, eps=1e-5, device=None, dtype=None):
-        super().__init__()
-        self.dim = dim
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))
-
-    def forward(self, x):
-        r"""
-        Args:
-            x(Tensor): Shape [B, L, C]
-        """
-        return self._norm(x.float()).type_as(x) * comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device)
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
-
-
 class WanSelfAttention(nn.Module):
 
     def __init__(self,
@@ -66,8 +49,8 @@ class WanSelfAttention(nn.Module):
         self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.norm_q = WanRMSNorm(dim, eps=eps, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
-        self.norm_k = WanRMSNorm(dim, eps=eps, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
+        self.norm_q = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
+        self.norm_k = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
 
     def forward(self, x, freqs):
         r"""
@@ -131,7 +114,7 @@ class WanI2VCrossAttention(WanSelfAttention):
         self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         # self.alpha = nn.Parameter(torch.zeros((1, )))
-        self.norm_k_img = WanRMSNorm(dim, eps=eps, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
+        self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
 
     def forward(self, x, context):
         r"""
@@ -639,7 +639,7 @@ class ModelPatcher:
                         mem_counter += module_mem
                         load_completely.append((module_mem, n, m, params))
 
-                if cast_weight:
+                if cast_weight and hasattr(m, "comfy_cast_weights"):
                     m.prev_comfy_cast_weights = m.comfy_cast_weights
                     m.comfy_cast_weights = True
 
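One plausible reading of the dtype fix in the ModelPatcher hunk: modules that never define comfy_cast_weights (plain torch modules such as the reused RMSNorm) would make the unguarded "if cast_weight:" branch fail, because reading m.comfy_cast_weights on them raises AttributeError; the hasattr guard simply skips them. A minimal standalone illustration with a stand-in module (not the actual ComfyUI code path):

import torch.nn as nn

m = nn.Identity()      # stand-in for a plain module without comfy_cast_weights
cast_weight = True     # e.g. the compute dtype was changed, so casting is requested

# Old behaviour: the attribute read on the right-hand side raises on plain modules.
try:
    m.prev_comfy_cast_weights = m.comfy_cast_weights
except AttributeError:
    pass  # this is the failure mode the guard avoids

# New behaviour: only modules that opt in via comfy_cast_weights are touched.
if cast_weight and hasattr(m, "comfy_cast_weights"):
    m.prev_comfy_cast_weights = m.comfy_cast_weights
    m.comfy_cast_weights = True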