Made HiDream work with optimized_attention_override

This commit is contained in:
Jedrzej Kosinski
2025-08-28 20:10:50 -07:00
parent f752715aac
commit 4cafd58f71

View File

@@ -72,8 +72,8 @@ class TimestepEmbed(nn.Module):
return t_emb return t_emb
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, transformer_options={}):
    """Flatten query/key/value and hand them to ``optimized_attention``.

    Each tensor is reshaped with ``view`` so that its first dimension is
    kept, its last two dimensions are merged into one, and everything in
    between is collapsed into a single middle dimension.

    Args:
        query: Query tensor; must have at least three dimensions because
            ``query.shape[2]`` is read.
        key: Key tensor, reshaped the same way as ``query``.
        value: Value tensor, reshaped the same way as ``query``.
        transformer_options: Passed through unchanged to
            ``optimized_attention`` as a keyword argument.

    Returns:
        Whatever ``optimized_attention`` returns for the reshaped inputs.
    """
    def _merge_trailing_dims(t: torch.Tensor) -> torch.Tensor:
        # Keep dim 0, fuse the last two dims, let -1 absorb the rest.
        return t.view(t.shape[0], -1, t.shape[-1] * t.shape[-2])

    # NOTE(review): presumably the head count for a (batch, seq, heads, head_dim)
    # layout -- confirm against optimized_attention's signature.
    size_of_dim_2 = query.shape[2]

    return optimized_attention(
        _merge_trailing_dims(query),
        _merge_trailing_dims(key),
        _merge_trailing_dims(value),
        size_of_dim_2,
        transformer_options=transformer_options,
    )
class HiDreamAttnProcessor_flashattn: class HiDreamAttnProcessor_flashattn:
@@ -86,6 +86,7 @@ class HiDreamAttnProcessor_flashattn:
image_tokens_masks: Optional[torch.FloatTensor] = None, image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None, text_tokens: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None, rope: torch.FloatTensor = None,
transformer_options={},
*args, *args,
**kwargs, **kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
@@ -133,7 +134,7 @@ class HiDreamAttnProcessor_flashattn:
query = torch.cat([query_1, query_2], dim=-1) query = torch.cat([query_1, query_2], dim=-1)
key = torch.cat([key_1, key_2], dim=-1) key = torch.cat([key_1, key_2], dim=-1)
hidden_states = attention(query, key, value) hidden_states = attention(query, key, value, transformer_options=transformer_options)
if not attn.single: if not attn.single:
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1) hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
@@ -199,6 +200,7 @@ class HiDreamAttention(nn.Module):
image_tokens_masks: torch.FloatTensor = None, image_tokens_masks: torch.FloatTensor = None,
norm_text_tokens: torch.FloatTensor = None, norm_text_tokens: torch.FloatTensor = None,
rope: torch.FloatTensor = None, rope: torch.FloatTensor = None,
transformer_options={},
) -> torch.Tensor: ) -> torch.Tensor:
return self.processor( return self.processor(
self, self,
@@ -206,6 +208,7 @@ class HiDreamAttention(nn.Module):
image_tokens_masks = image_tokens_masks, image_tokens_masks = image_tokens_masks,
text_tokens = norm_text_tokens, text_tokens = norm_text_tokens,
rope = rope, rope = rope,
transformer_options=transformer_options,
) )
@@ -406,7 +409,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None, text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None, adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None, rope: torch.FloatTensor = None,
transformer_options={},
) -> torch.FloatTensor: ) -> torch.FloatTensor:
wtype = image_tokens.dtype wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \ shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
@@ -419,6 +422,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
norm_image_tokens, norm_image_tokens,
image_tokens_masks, image_tokens_masks,
rope = rope, rope = rope,
transformer_options=transformer_options,
) )
image_tokens = gate_msa_i * attn_output_i + image_tokens image_tokens = gate_msa_i * attn_output_i + image_tokens
@@ -483,6 +487,7 @@ class HiDreamImageTransformerBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None, text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None, adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None, rope: torch.FloatTensor = None,
transformer_options={},
) -> torch.FloatTensor: ) -> torch.FloatTensor:
wtype = image_tokens.dtype wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \ shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
@@ -500,6 +505,7 @@ class HiDreamImageTransformerBlock(nn.Module):
image_tokens_masks, image_tokens_masks,
norm_text_tokens, norm_text_tokens,
rope = rope, rope = rope,
transformer_options=transformer_options,
) )
image_tokens = gate_msa_i * attn_output_i + image_tokens image_tokens = gate_msa_i * attn_output_i + image_tokens
@@ -550,6 +556,7 @@ class HiDreamImageBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None, text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: torch.FloatTensor = None, adaln_input: torch.FloatTensor = None,
rope: torch.FloatTensor = None, rope: torch.FloatTensor = None,
transformer_options={},
) -> torch.FloatTensor: ) -> torch.FloatTensor:
return self.block( return self.block(
image_tokens, image_tokens,
@@ -557,6 +564,7 @@ class HiDreamImageBlock(nn.Module):
text_tokens, text_tokens,
adaln_input, adaln_input,
rope, rope,
transformer_options=transformer_options,
) )
@@ -786,6 +794,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
text_tokens = cur_encoder_hidden_states, text_tokens = cur_encoder_hidden_states,
adaln_input = adaln_input, adaln_input = adaln_input,
rope = rope, rope = rope,
transformer_options=transformer_options,
) )
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len] initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
block_id += 1 block_id += 1
@@ -809,6 +818,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
text_tokens=None, text_tokens=None,
adaln_input=adaln_input, adaln_input=adaln_input,
rope=rope, rope=rope,
transformer_options=transformer_options,
) )
hidden_states = hidden_states[:, :hidden_states_seq_len] hidden_states = hidden_states[:, :hidden_states_seq_len]
block_id += 1 block_id += 1