Made AuraFlow work with optimized_attention_override

Author: Jedrzej Kosinski
Date: 2025-08-28 21:46:56 -07:00
Parent: 034d6c12e6
Commit: 17090c56be
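This change threads `transformer_options` through AuraFlow's attention and block forward methods so the shared `optimized_attention` helper can see a per-run `optimized_attention_override`. Below is a minimal sketch of how such an override might be attached from patching code; the key name comes from the commit title, while the callback signature and wiring are assumptions to verify against comfy/ldm/modules/attention.py, not a documented API.

    # Sketch only: attach a custom attention override via transformer_options
    # so the forwards patched in this commit can pick it up. The callback
    # signature (selected attention function first, then q/k/v/heads) is an
    # assumption based on this commit's intent.
    import torch

    def logging_attention_override(attention_fn, q: torch.Tensor, k: torch.Tensor,
                                   v: torch.Tensor, heads: int, **kwargs):
        # Inspect or replace the call, then defer to the backend ComfyUI chose.
        print(f"attention: q={tuple(q.shape)} heads={heads}")
        return attention_fn(q, k, v, heads, **kwargs)

    def attach_override(model_patcher):
        # model_options["transformer_options"] is the dict that ultimately
        # reaches optimized_attention(..., transformer_options=...) below.
        opts = model_patcher.model_options.setdefault("transformer_options", {})
        opts["optimized_attention_override"] = logging_attention_override

With something like this registered, the `transformer_options={}` defaults added in the diff only apply when a caller omits the dict; the MMDiT forward path now forwards the real options explicitly.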


@@ -85,7 +85,7 @@ class SingleAttention(nn.Module):
         )
     #@torch.compile()
-    def forward(self, c):
+    def forward(self, c, transformer_options={}):
         bsz, seqlen1, _ = c.shape
@@ -95,7 +95,7 @@ class SingleAttention(nn.Module):
         v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
         q, k = self.q_norm1(q), self.k_norm1(k)
-        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
+        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
         c = self.w1o(output)
         return c
@@ -144,7 +144,7 @@ class DoubleAttention(nn.Module):
     #@torch.compile()
-    def forward(self, c, x):
+    def forward(self, c, x, transformer_options={}):
         bsz, seqlen1, _ = c.shape
         bsz, seqlen2, _ = x.shape
@@ -168,7 +168,7 @@ class DoubleAttention(nn.Module):
             torch.cat([cv, xv], dim=1),
         )
-        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
+        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
         c, x = output.split([seqlen1, seqlen2], dim=1)
         c = self.w1o(c)
@@ -207,7 +207,7 @@ class MMDiTBlock(nn.Module):
         self.is_last = is_last
     #@torch.compile()
-    def forward(self, c, x, global_cond, **kwargs):
+    def forward(self, c, x, global_cond, transformer_options={}, **kwargs):
         cres, xres = c, x
@@ -225,7 +225,7 @@ class MMDiTBlock(nn.Module):
         x = modulate(self.normX1(x), xshift_msa, xscale_msa)
         # attention
-        c, x = self.attn(c, x)
+        c, x = self.attn(c, x, transformer_options=transformer_options)
         c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
@@ -255,13 +255,13 @@ class DiTBlock(nn.Module):
         self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
     #@torch.compile()
-    def forward(self, cx, global_cond, **kwargs):
+    def forward(self, cx, global_cond, transformer_options={}, **kwargs):
         cxres = cx
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
             global_cond
         ).chunk(6, dim=1)
         cx = modulate(self.norm1(cx), shift_msa, scale_msa)
-        cx = self.attn(cx)
+        cx = self.attn(cx, transformer_options=transformer_options)
         cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
         mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
         cx = gate_mlp.unsqueeze(1) * mlpout
@@ -473,13 +473,14 @@ class MMDiT(nn.Module):
                     out = {}
                     out["txt"], out["img"] = layer(args["txt"],
                                                    args["img"],
-                                                   args["vec"])
+                                                   args["vec"],
+                                                   transformer_options=args["transformer_options"])
                     return out
-                out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
                 c = out["txt"]
                 x = out["img"]
             else:
-                c, x = layer(c, x, global_cond, **kwargs)
+                c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs)
         if len(self.single_layers) > 0:
             c_len = c.size(1)
@@ -488,13 +489,13 @@ class MMDiT(nn.Module):
if ("single_block", i) in blocks_replace: if ("single_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["img"] = layer(args["img"], args["vec"]) out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"])
return out return out
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap}) out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
cx = out["img"] cx = out["img"]
else: else:
cx = layer(cx, global_cond, **kwargs) cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs)
x = cx[:, c_len:] x = cx[:, c_len:]
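For block-replace patches, the args dict handed to the patch now carries the same options. The sketch below mirrors the ("double_block", i) call sites in this diff; how the patch ends up in blocks_replace (typically via the ModelPatcher patch-replace helpers) is outside this diff and assumed.

    # Sketch only: a ("double_block", i) replace patch compatible with the
    # blocks_replace call sites above. Registration is not shown here.
    def double_block_patch(args, extra):
        original_block = extra["original_block"]
        # As of this commit, args also carries "transformer_options", which
        # block_wrap forwards into the underlying layer call.
        opts = args.get("transformer_options", {})
        out = original_block(args)
        # A real patch could post-process out["img"] / out["txt"] here,
        # using opts to decide how.
        return out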