From fe31ad02768c66c61b3dc12f5d4bdfe8990ce25c Mon Sep 17 00:00:00 2001
From: contentis
Date: Sat, 23 Aug 2025 01:39:15 +0200
Subject: [PATCH] Add elementwise fusions (#9495)

* Add elementwise fusions

* Add addcmul pattern to Qwen
---
 comfy/ldm/modules/diffusionmodules/mmdit.py | 12 +++++++-----
 comfy/ldm/qwen_image/model.py               | 12 ++++++------
 comfy/ldm/wan/model.py                      | 14 +++++++-------
 3 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py
index eaf3e73a4..4d6beba2d 100644
--- a/comfy/ldm/modules/diffusionmodules/mmdit.py
+++ b/comfy/ldm/modules/diffusionmodules/mmdit.py
@@ -109,7 +109,7 @@ class PatchEmbed(nn.Module):
 def modulate(x, shift, scale):
     if shift is None:
         shift = torch.zeros_like(scale)
-    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+    return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1))


 #################################################################################
@@ -564,10 +564,7 @@ class DismantledBlock(nn.Module):
         assert not self.pre_only
         attn1 = self.attn.post_attention(attn)
         attn2 = self.attn2.post_attention(attn2)
-        out1 = gate_msa.unsqueeze(1) * attn1
-        out2 = gate_msa2.unsqueeze(1) * attn2
-        x = x + out1
-        x = x + out2
+        x = gate_cat(x, gate_msa, gate_msa2, attn1, attn2)
         x = x + gate_mlp.unsqueeze(1) * self.mlp(
             modulate(self.norm2(x), shift_mlp, scale_mlp)
         )
@@ -594,6 +591,11 @@ class DismantledBlock(nn.Module):
         )
         return self.post_attention(attn, *intermediates)

+def gate_cat(x, gate_msa, gate_msa2, attn1, attn2):
+    out1 = gate_msa.unsqueeze(1) * attn1
+    out2 = gate_msa2.unsqueeze(1) * attn2
+    x = torch.stack([x, out1, out2], dim=0).sum(dim=0)
+    return x

 def block_mixing(*args, use_checkpoint=True, **kwargs):
     if use_checkpoint:
diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index d0e39833a..af00ff119 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -214,9 +214,9 @@ class QwenImageTransformerBlock(nn.Module):
             operations=operations,
         )

-    def _modulate(self, x, mod_params):
-        shift, scale, gate = mod_params.chunk(3, dim=-1)
-        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+    def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
+        return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)

     def forward(
         self,
@@ -248,11 +248,11 @@ class QwenImageTransformerBlock(nn.Module):

         img_normed2 = self.img_norm2(hidden_states)
         img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
-        hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
+        hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))

         txt_normed2 = self.txt_norm2(encoder_hidden_states)
         txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
-        encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
+        encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))

         return encoder_hidden_states, hidden_states

@@ -275,7 +275,7 @@ class LastLayer(nn.Module):
     def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
         emb = self.linear(self.silu(conditioning_embedding))
         scale, shift = torch.chunk(emb, 2, dim=1)
-        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :])
         return x


diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index 9d3741be3..0726b8e1b 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -148,8 +148,8 @@ WAN_CROSSATTENTION_CLASSES = {

 def repeat_e(e, x):
     repeats = 1
-    if e.shape[1] > 1:
-        repeats = x.shape[1] // e.shape[1]
+    if e.size(1) > 1:
+        repeats = x.size(1) // e.size(1)
     if repeats == 1:
         return e
     return torch.repeat_interleave(e, repeats, dim=1)
@@ -219,15 +219,15 @@ class WanAttentionBlock(nn.Module):

         # self-attention
         y = self.self_attn(
-            self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x),
+            torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
             freqs)

-        x = x + y * repeat_e(e[2], x)
+        x = torch.addcmul(x, y, repeat_e(e[2], x))

         # cross-attention & ffn
         x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
-        y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x))
-        x = x + y * repeat_e(e[5], x)
+        y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
+        x = torch.addcmul(x, y, repeat_e(e[5], x))

         return x

@@ -342,7 +342,7 @@ class Head(nn.Module):
         else:
             e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)

-        x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)))
+        x = (self.head(torch.addcmul(repeat_e(e[0], x), self.norm(x), 1 + repeat_e(e[1], x))))
         return x
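
Reviewer note (not part of the patch): the rewrites above all lean on torch.addcmul(input, t1, t2), which computes input + t1 * t2 as one fused elementwise op instead of materializing the intermediate product. A minimal standalone sketch of the equivalence, using illustrative (hypothetical) tensor shapes rather than the models' real ones:

import torch

torch.manual_seed(0)
B, T, C = 2, 8, 16            # hypothetical batch / sequence / channel sizes
x = torch.randn(B, T, C)      # stand-in for hidden states
shift = torch.randn(B, C)
scale = torch.randn(B, C)
gate = torch.randn(B, 1, C)
y = torch.randn(B, T, C)      # stand-in for an attention or MLP output

# modulate-style fusion: x * (1 + scale) + shift  ->  addcmul(shift, x, 1 + scale)
unfused = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
fused = torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1))
assert torch.allclose(unfused, fused)

# gated-residual fusion: x + gate * y  ->  addcmul(x, gate, y)
assert torch.allclose(x + gate * y, torch.addcmul(x, gate, y))

The new gate_cat helper applies the same idea to the double residual in DismantledBlock, summing x, gate_msa * attn1, and gate_msa2 * attn2 in one stacked reduction rather than two separate adds.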