Mirror of https://github.com/comfyanonymous/ComfyUI.git

Commit: Made Lumina work with optimized_attention_override

This commit threads transformer_options from NextDiT's forward pass down through patchify_and_embed, the context and noise refiners, JointTransformerBlock, and JointAttention, and forwards it to optimized_attention_masked so that an optimized_attention_override supplied via transformer_options can take effect in Lumina's attention layers.
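The diff below only threads the options dict through Lumina's call chain; the commit title implies that optimized_attention_masked (defined elsewhere in ComfyUI) consults transformer_options for an override callable. As a rough, non-authoritative illustration of that mechanism, here is a minimal Python sketch of an attention helper honoring such an override. Apart from the transformer_options parameter itself, the function name, key name, and fallback path are assumptions, not code from this commit.

# Minimal sketch, NOT ComfyUI's actual implementation: an attention helper
# that dispatches to an override callable carried in transformer_options.
# The key "optimized_attention_override" is assumed from the commit title.
import torch.nn.functional as F


def attention_masked_with_override(q, k, v, heads, mask=None, skip_reshape=False, transformer_options={}):
    # q, k, v arrive as (batch, heads, seq_len, head_dim) when skip_reshape=True.
    override = transformer_options.get("optimized_attention_override")
    if override is not None:
        # Hand the whole call to the user-supplied implementation.
        return override(q, k, v, heads, mask=mask, skip_reshape=skip_reshape)

    if mask is not None and mask.ndim == 2:
        # Broadcast a (batch, seq_len) padding mask to the shape SDPA expects.
        mask = mask[:, None, None, :]
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    b, h, l, d = out.shape
    # Flatten heads back to (batch, seq_len, heads * head_dim) for the output projection.
    return out.transpose(1, 2).reshape(b, l, h * d)

In the real code the fallback is ComfyUI's existing kernel selection; only the override lookup is the point of the sketch.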
@@ -104,6 +104,7 @@ class JointAttention(nn.Module):
         x: torch.Tensor,
         x_mask: torch.Tensor,
         freqs_cis: torch.Tensor,
+        transformer_options={},
     ) -> torch.Tensor:
         """

@@ -140,7 +141,7 @@ class JointAttention(nn.Module):
         if n_rep >= 1:
             xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
             xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
-        output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)
+        output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options)

         return self.out(output)

@@ -268,6 +269,7 @@ class JointTransformerBlock(nn.Module):
         x_mask: torch.Tensor,
         freqs_cis: torch.Tensor,
         adaln_input: Optional[torch.Tensor]=None,
+        transformer_options={},
     ):
         """
         Perform a forward pass through the TransformerBlock.
@@ -290,6 +292,7 @@ class JointTransformerBlock(nn.Module):
                     modulate(self.attention_norm1(x), scale_msa),
                     x_mask,
                     freqs_cis,
+                    transformer_options=transformer_options,
                 )
             )
             x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
@@ -304,6 +307,7 @@ class JointTransformerBlock(nn.Module):
                     self.attention_norm1(x),
                     x_mask,
                     freqs_cis,
+                    transformer_options=transformer_options,
                 )
             )
             x = x + self.ffn_norm2(
@@ -494,7 +498,7 @@ class NextDiT(nn.Module):
         return imgs

     def patchify_and_embed(
-        self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
+        self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
     ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
         bsz = len(x)
         pH = pW = self.patch_size
@@ -554,7 +558,7 @@ class NextDiT(nn.Module):

         # refine context
         for layer in self.context_refiner:
-            cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
+            cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)

         # refine image
         flat_x = []
@@ -573,7 +577,7 @@ class NextDiT(nn.Module):
         padded_img_embed = self.x_embedder(padded_img_embed)
         padded_img_mask = padded_img_mask.unsqueeze(1)
         for layer in self.noise_refiner:
-            padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
+            padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)

         if cap_mask is not None:
             mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
@@ -616,12 +620,13 @@ class NextDiT(nn.Module):

         cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute

+        transformer_options = kwargs.get("transformer_options", {})
         x_is_tensor = isinstance(x, torch.Tensor)
-        x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
+        x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
         freqs_cis = freqs_cis.to(x.device)

         for layer in self.layers:
-            x = layer(x, mask, freqs_cis, adaln_input)
+            x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)

         x = self.final_layer(x, adaln_input)
         x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
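On the caller side, the hunk at -616 shows that NextDiT pulls the dict out of kwargs with kwargs.get("transformer_options", {}). The following hedged usage sketch assumes the "optimized_attention_override" key from the commit title and uses placeholder model and input names; in ComfyUI the dict is normally populated through the model patcher and sampling machinery rather than built by hand.

# Hypothetical caller-side sketch: register a custom attention callable in
# transformer_options so the Lumina layers patched above can pick it up.
import torch.nn.functional as F


def my_attention_override(q, k, v, heads, mask=None, skip_reshape=False, **kwargs):
    # Stand-in for a custom kernel (flash/sage attention, instrumentation, etc.).
    if mask is not None and mask.ndim == 2:
        mask = mask[:, None, None, :]
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    b, h, l, d = out.shape
    # Assumed contract: return (batch, seq_len, heads * head_dim) like the default helper.
    return out.transpose(1, 2).reshape(b, l, h * d)


transformer_options = {"optimized_attention_override": my_attention_override}

# `model`, `x`, `timesteps`, `context`, `num_tokens` are placeholders for a loaded
# NextDiT instance and its usual inputs; the exact forward signature is not shown
# in this diff, so the call is left as an illustrative comment.
# out = model(x, timesteps, context, num_tokens, transformer_options=transformer_options)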