Made Lumina work with optimized_attention_override

Author: Jedrzej Kosinski
Date:   2025-08-28 22:00:44 -07:00
Parent: 17090c56be
Commit: d644aba6bc

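The change threads a transformer_options dict from NextDiT.forward down through every attention call so that optimized_attention_masked can consult it for an attention override. Below is a minimal sketch of the dispatch pattern this plumbing enables; only the "optimized_attention_override" key comes from the commit title, while the helper body and override signature are illustrative assumptions, not ComfyUI's exact implementation:

import torch
import torch.nn.functional as F

def default_attention(q, k, v, heads, mask=None):
    # Fallback path: plain scaled-dot-product attention on tensors that
    # are already shaped (batch, heads, seq_len, head_dim).
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

def optimized_attention_masked(q, k, v, heads, mask=None, skip_reshape=False, transformer_options={}):
    # Sketch only: reshape handling (skip_reshape) is omitted here.
    # If a caller registered an override in transformer_options, delegate
    # to it, handing over the default path so the override can wrap it.
    override = transformer_options.get("optimized_attention_override")
    if override is not None:
        return override(default_attention, q, k, v, heads, mask=mask)
    return default_attention(q, k, v, heads, mask=mask)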

@@ -104,6 +104,7 @@ class JointAttention(nn.Module):
         x: torch.Tensor,
         x_mask: torch.Tensor,
         freqs_cis: torch.Tensor,
+        transformer_options={},
     ) -> torch.Tensor:
         """
@@ -140,7 +141,7 @@ class JointAttention(nn.Module):
             if n_rep >= 1:
                 xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
                 xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
-        output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)
+        output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options)
 
         return self.out(output)
@@ -268,6 +269,7 @@ class JointTransformerBlock(nn.Module):
         x_mask: torch.Tensor,
         freqs_cis: torch.Tensor,
         adaln_input: Optional[torch.Tensor]=None,
+        transformer_options={},
     ):
         """
         Perform a forward pass through the TransformerBlock.
@@ -290,6 +292,7 @@ class JointTransformerBlock(nn.Module):
                     modulate(self.attention_norm1(x), scale_msa),
                     x_mask,
                     freqs_cis,
+                    transformer_options=transformer_options,
                 )
             )
             x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
@@ -304,6 +307,7 @@ class JointTransformerBlock(nn.Module):
                     self.attention_norm1(x),
                     x_mask,
                     freqs_cis,
+                    transformer_options=transformer_options,
                 )
             )
             x = x + self.ffn_norm2(
@@ -494,7 +498,7 @@ class NextDiT(nn.Module):
         return imgs
 
     def patchify_and_embed(
-        self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
+        self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
     ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
         bsz = len(x)
         pH = pW = self.patch_size
@@ -554,7 +558,7 @@ class NextDiT(nn.Module):
         # refine context
         for layer in self.context_refiner:
-            cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
+            cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
 
         # refine image
         flat_x = []
@@ -573,7 +577,7 @@ class NextDiT(nn.Module):
         padded_img_embed = self.x_embedder(padded_img_embed)
         padded_img_mask = padded_img_mask.unsqueeze(1)
         for layer in self.noise_refiner:
-            padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
+            padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
 
         if cap_mask is not None:
             mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
@@ -616,12 +620,13 @@ class NextDiT(nn.Module):
         cap_feats = self.cap_embedder(cap_feats)  # (N, L, D)  # todo check if able to batchify w.o. redundant compute
 
+        transformer_options = kwargs.get("transformer_options", {})
         x_is_tensor = isinstance(x, torch.Tensor)
-        x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
+        x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
         freqs_cis = freqs_cis.to(x.device)
 
         for layer in self.layers:
-            x = layer(x, mask, freqs_cis, adaln_input)
+            x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
 
         x = self.final_layer(x, adaln_input)
         x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
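A hypothetical usage sketch to go with the dispatch sketch above: a caller registers an override under the "optimized_attention_override" key and passes the dict as a keyword argument to NextDiT.forward, which reads it via kwargs.get("transformer_options", {}) per the last hunk. The key name follows the commit title; the override signature and where the dict enters in a full pipeline (e.g. via model_options["transformer_options"] in ComfyUI) are assumptions:

def logging_attention(default_attention, q, k, v, heads, mask=None):
    # Wraps the default path; here it only records call shapes, but the
    # same hook point could swap in a different attention kernel.
    print(f"attention: q={tuple(q.shape)}, heads={heads}")
    return default_attention(q, k, v, heads, mask=mask)

transformer_options = {"optimized_attention_override": logging_attention}
# model(x, t, cap_feats, cap_mask, transformer_options=transformer_options)
# then carries the dict through patchify_and_embed, the context/noise
# refiners, and every JointTransformerBlock down to the attention call.

Passing the dict explicitly through each layer, rather than storing it on the module, keeps the override per-call and avoids stale state between sampling runs.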