Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-09-11 20:17:30 +00:00)
Made AuraFlow work with optimized_attention_override
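
This commit threads a transformer_options dict from MMDiT.forward down through MMDiTBlock, DiTBlock, DoubleAttention and SingleAttention (presumably comfy/ldm/aura/mmdit.py) into every optimized_attention call, so an attention override supplied through transformer_options can take effect for AuraFlow. As a minimal sketch of what such an override could look like on the user side, assuming the callable is stored under the "optimized_attention_override" key (the key name comes from the commit title; the callable's exact signature is an assumption, not ComfyUI's documented API):

import torch
import torch.nn.functional as F

def my_attention_override(q, k, v, heads, skip_reshape=False, **kwargs):
    # The AuraFlow layers below call optimized_attention with skip_reshape=True and
    # q/k/v already permuted to (batch, heads, seq, head_dim).
    b, h, l, d = q.shape
    out = F.scaled_dot_product_attention(q, k, v)      # plain SDPA stand-in
    return out.transpose(1, 2).reshape(b, l, h * d)    # (batch, seq, heads * head_dim)

# transformer_options is normally populated from the ModelPatcher's model_options, e.g. (illustrative):
# model.model_options["transformer_options"]["optimized_attention_override"] = my_attention_override
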
@@ -85,7 +85,7 @@ class SingleAttention(nn.Module):
         )

     #@torch.compile()
-    def forward(self, c):
+    def forward(self, c, transformer_options={}):

         bsz, seqlen1, _ = c.shape

@@ -95,7 +95,7 @@ class SingleAttention(nn.Module):
         v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
         q, k = self.q_norm1(q), self.k_norm1(k)

-        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
+        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
         c = self.w1o(output)
         return c

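
For context, the receiving side: optimized_attention (imported from comfy.ldm.modules.attention) now gets the forwarded dictionary as a keyword argument. A simplified, illustrative sketch of how a dispatcher could consume it; this is not ComfyUI's actual implementation:

import torch.nn.functional as F

def optimized_attention_sketch(q, k, v, heads, skip_reshape=False, transformer_options={}):
    # Illustrative only: delegate to a user-supplied override if one is present,
    # otherwise fall back to scaled dot-product attention.
    override = transformer_options.get("optimized_attention_override")
    if override is not None:
        return override(q, k, v, heads, skip_reshape=skip_reshape)
    if not skip_reshape:
        # Inputs shaped (batch, seq, heads * head_dim); split out the heads.
        q, k, v = (t.view(t.shape[0], t.shape[1], heads, -1).transpose(1, 2) for t in (q, k, v))
    b, h, l, d = q.shape
    out = F.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2).reshape(b, l, h * d)
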
@@ -144,7 +144,7 @@ class DoubleAttention(nn.Module):


     #@torch.compile()
-    def forward(self, c, x):
+    def forward(self, c, x, transformer_options={}):

         bsz, seqlen1, _ = c.shape
         bsz, seqlen2, _ = x.shape
@@ -168,7 +168,7 @@ class DoubleAttention(nn.Module):
             torch.cat([cv, xv], dim=1),
         )

-        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
+        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)

         c, x = output.split([seqlen1, seqlen2], dim=1)
         c = self.w1o(c)
@@ -207,7 +207,7 @@ class MMDiTBlock(nn.Module):
         self.is_last = is_last

     #@torch.compile()
-    def forward(self, c, x, global_cond, **kwargs):
+    def forward(self, c, x, global_cond, transformer_options={}, **kwargs):

         cres, xres = c, x

@@ -225,7 +225,7 @@ class MMDiTBlock(nn.Module):
         x = modulate(self.normX1(x), xshift_msa, xscale_msa)

         # attention
-        c, x = self.attn(c, x)
+        c, x = self.attn(c, x, transformer_options=transformer_options)


         c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
@@ -255,13 +255,13 @@ class DiTBlock(nn.Module):
         self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)

     #@torch.compile()
-    def forward(self, cx, global_cond, **kwargs):
+    def forward(self, cx, global_cond, transformer_options={}, **kwargs):
         cxres = cx
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
             global_cond
         ).chunk(6, dim=1)
         cx = modulate(self.norm1(cx), shift_msa, scale_msa)
-        cx = self.attn(cx)
+        cx = self.attn(cx, transformer_options=transformer_options)
         cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
         mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
         cx = gate_mlp.unsqueeze(1) * mlpout
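
All four forward signatures keep transformer_options as a keyword argument with an empty-dict default, so existing call sites that never pass it keep working unchanged. A toy pass-through (names hypothetical) showing the pattern the commit applies to each block:

import torch
import torch.nn as nn

def toy_attention(x, transformer_options={}):
    # Stand-in for optimized_attention: just report whether an override is visible here.
    print("override visible:", "optimized_attention_override" in transformer_options)
    return x

class ToyBlock(nn.Module):
    # Mirrors the commit's pattern: accept transformer_options with a default and
    # hand it straight to the attention call.
    def forward(self, x, transformer_options={}):
        return toy_attention(x, transformer_options=transformer_options)

block = ToyBlock()
x = torch.randn(1, 4, 8)
block(x)                                                                    # old-style call: override not visible
block(x, transformer_options={"optimized_attention_override": lambda *a, **k: None})  # new-style call: override visible
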
@@ -473,13 +473,14 @@ class MMDiT(nn.Module):
                     out = {}
                     out["txt"], out["img"] = layer(args["txt"],
                                                    args["img"],
-                                                   args["vec"])
+                                                   args["vec"],
+                                                   transformer_options=args["transformer_options"])
                     return out
-                out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
                 c = out["txt"]
                 x = out["img"]
             else:
-                c, x = layer(c, x, global_cond, **kwargs)
+                c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs)

         if len(self.single_layers) > 0:
             c_len = c.size(1)
@@ -488,13 +489,13 @@ class MMDiT(nn.Module):
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"] = layer(args["img"], args["vec"])
+                    out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"])
                     return out

-                out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
+                out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
                 cx = out["img"]
             else:
-                cx = layer(cx, global_cond, **kwargs)
+                cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs)

         x = cx[:, c_len:]

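
The two MMDiT.forward hunks also expose the dictionary to blocks_replace patches: the args handed to a registered "double_block" or "single_block" replacement now include "transformer_options", and block_wrap forwards it to the wrapped layer. A sketch of a replacement patch written against the call shape shown above (the patch body is hypothetical; the registration call at the end is illustrative):

def my_double_block_patch(args, extra):
    # args carries "img", "txt", "vec" and now "transformer_options";
    # extra["original_block"] is the block_wrap closure from the diff above.
    opts = args.get("transformer_options", {})
    # ... inspect opts, or modify the inputs here ...
    out = extra["original_block"](args)   # runs the real layer with transformer_options attached
    return out                            # expected to return {"txt": ..., "img": ...}

# Registration is typically done through the ModelPatcher, e.g. (illustrative):
# model.set_model_patch_replace(my_double_block_patch, "dit", "double_block", block_index)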