Add elementwise fusions (#9495)

* Add elementwise fusions

* Add addcmul pattern to Qwen
contentis authored on 2025-08-23 01:39:15 +02:00; committed by GitHub
parent ca4e96a8ae
commit fe31ad0276
3 changed files with 20 additions and 18 deletions

View File

@@ -109,7 +109,7 @@ class PatchEmbed(nn.Module):
 def modulate(x, shift, scale):
     if shift is None:
         shift = torch.zeros_like(scale)
-    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+    return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1))

 #################################################################################
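torch.addcmul(input, t1, t2) computes input + t1 * t2 as a single elementwise op, so the fused modulate above returns the same values as the old x * (1 + scale) + shift form while skipping one intermediate tensor. A minimal equivalence sketch; the tensor shapes are illustrative assumptions, not taken from the model code:

import torch

x = torch.randn(2, 16, 64)        # (batch, tokens, channels) -- assumed shapes
shift = torch.randn(2, 64)
scale = torch.randn(2, 64)

eager = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
fused = torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1))
assert torch.allclose(eager, fused)   # addcmul(a, b, c) == a + b * c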
@@ -564,10 +564,7 @@ class DismantledBlock(nn.Module):
         assert not self.pre_only
         attn1 = self.attn.post_attention(attn)
         attn2 = self.attn2.post_attention(attn2)
-        out1 = gate_msa.unsqueeze(1) * attn1
-        out2 = gate_msa2.unsqueeze(1) * attn2
-        x = x + out1
-        x = x + out2
+        x = gate_cat(x, gate_msa, gate_msa2, attn1, attn2)
         x = x + gate_mlp.unsqueeze(1) * self.mlp(
             modulate(self.norm2(x), shift_mlp, scale_mlp)
         )
@@ -594,6 +591,11 @@ class DismantledBlock(nn.Module):
         )
         return self.post_attention(attn, *intermediates)

+def gate_cat(x, gate_msa, gate_msa2, attn1, attn2):
+    out1 = gate_msa.unsqueeze(1) * attn1
+    out2 = gate_msa2.unsqueeze(1) * attn2
+    x = torch.stack([x, out1, out2], dim=0).sum(dim=0)
+    return x

 def block_mixing(*args, use_checkpoint=True, **kwargs):
     if use_checkpoint:
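The new gate_cat helper folds the two gated residual adds that previously ran as separate x = x + out1; x = x + out2 statements into one stack-and-sum, which is the same arithmetic (up to floating-point accumulation order) expressed as a single reduction. A small sanity-check sketch, assuming toy shapes:

import torch

def gate_cat(x, gate_msa, gate_msa2, attn1, attn2):
    out1 = gate_msa.unsqueeze(1) * attn1
    out2 = gate_msa2.unsqueeze(1) * attn2
    return torch.stack([x, out1, out2], dim=0).sum(dim=0)

x, attn1, attn2 = (torch.randn(2, 16, 64) for _ in range(3))    # assumed shapes
gate_msa, gate_msa2 = (torch.randn(2, 64) for _ in range(2))

reference = x + gate_msa.unsqueeze(1) * attn1 + gate_msa2.unsqueeze(1) * attn2
assert torch.allclose(gate_cat(x, gate_msa, gate_msa2, attn1, attn2), reference)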

View File

@@ -214,9 +214,9 @@ class QwenImageTransformerBlock(nn.Module):
             operations=operations,
         )

-    def _modulate(self, x, mod_params):
-        shift, scale, gate = mod_params.chunk(3, dim=-1)
-        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+    def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
+        return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)

     def forward(
         self,
@@ -248,11 +248,11 @@ class QwenImageTransformerBlock(nn.Module):
         img_normed2 = self.img_norm2(hidden_states)
         img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
-        hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
+        hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))

         txt_normed2 = self.txt_norm2(encoder_hidden_states)
         txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
-        encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
+        encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))

         return encoder_hidden_states, hidden_states
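The addcmul form of these residual updates also avoids materialising the intermediate gate * mlp_output tensor, which is typically where the saving comes from on memory-bound elementwise work. A rough micro-benchmark sketch; the shapes and the CPU timing loop are assumptions for illustration, not measurements from this PR:

import time
import torch

x = torch.randn(2, 2048, 3072)        # assumed activation shape
gate = torch.randn(2, 1, 3072)        # broadcast gate, shaped like gate.unsqueeze(1)
y = torch.randn(2, 2048, 3072)        # stand-in for the MLP output

def bench(fn, iters=20):
    fn()                              # warm-up
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - t0) / iters

print("unfused  x + gate * y    :", bench(lambda: x + gate * y))
print("fused    addcmul(x, g, y):", bench(lambda: torch.addcmul(x, gate, y)))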
@@ -275,7 +275,7 @@ class LastLayer(nn.Module):
     def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
         emb = self.linear(self.silu(conditioning_embedding))
         scale, shift = torch.chunk(emb, 2, dim=1)
-        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :])
         return x

View File

@@ -148,8 +148,8 @@ WAN_CROSSATTENTION_CLASSES = {
 def repeat_e(e, x):
     repeats = 1
-    if e.shape[1] > 1:
-        repeats = x.shape[1] // e.shape[1]
+    if e.size(1) > 1:
+        repeats = x.size(1) // e.size(1)
     if repeats == 1:
         return e
     return torch.repeat_interleave(e, repeats, dim=1)
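Apart from the cosmetic .shape[i] -> .size(i) change, repeat_e is unchanged: it tiles the modulation tensor e along dim 1 so it lines up with x before the fused addcmul calls below. A short sketch of that behaviour, with assumed shapes:

import torch

def repeat_e(e, x):
    repeats = 1
    if e.size(1) > 1:
        repeats = x.size(1) // e.size(1)
    if repeats == 1:
        return e
    return torch.repeat_interleave(e, repeats, dim=1)

x = torch.randn(2, 16, 64)    # 16 tokens (assumed)
e = torch.randn(2, 4, 64)     # 4 modulation entries -> each repeated 4 times
print(repeat_e(e, x).shape)   # torch.Size([2, 16, 64])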
@@ -219,15 +219,15 @@ class WanAttentionBlock(nn.Module):
         # self-attention
         y = self.self_attn(
-            self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x),
+            torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
             freqs)

-        x = x + y * repeat_e(e[2], x)
+        x = torch.addcmul(x, y, repeat_e(e[2], x))

         # cross-attention & ffn
         x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
-        y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x))
-        x = x + y * repeat_e(e[5], x)
+        y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
+        x = torch.addcmul(x, y, repeat_e(e[5], x))
         return x
@@ -342,7 +342,7 @@ class Head(nn.Module):
         else:
             e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)

-        x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)))
+        x = (self.head(torch.addcmul(repeat_e(e[0], x), self.norm(x), 1 + repeat_e(e[1], x))))
         return x