Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-09-13 13:05:07 +00:00
Made hidream work with optimized_attention_override
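The diff below threads a transformer_options dict through every HiDream attention call site so the shared optimized_attention helper can see it. As a rough illustration of what this plumbing enables, here is a minimal caller-side sketch: the "optimized_attention_override" key comes from the commit title, but the override's signature and the way it is registered are assumptions for illustration, not part of this commit.

def my_attention_override(default_fn, q, k, v, heads, **kwargs):
    # Hypothetical override: inspect the shapes, then defer to the normal attention path.
    print("attention call:", q.shape, k.shape, v.shape, "heads =", heads)
    return default_fn(q, k, v, heads, **kwargs)

transformer_options = {"optimized_attention_override": my_attention_override}
# After this commit, a dict like this handed to the model forward reaches the
# innermost attention() call in every HiDream block.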
@@ -72,8 +72,8 @@ class TimestepEmbed(nn.Module):
         return t_emb


-def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
-    return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
+def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, transformer_options={}):
+    return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2], transformer_options=transformer_options)


 class HiDreamAttnProcessor_flashattn:
@@ -86,6 +86,7 @@ class HiDreamAttnProcessor_flashattn:
         image_tokens_masks: Optional[torch.FloatTensor] = None,
         text_tokens: Optional[torch.FloatTensor] = None,
         rope: torch.FloatTensor = None,
+        transformer_options={},
         *args,
         **kwargs,
     ) -> torch.FloatTensor:
@@ -133,7 +134,7 @@ class HiDreamAttnProcessor_flashattn:
         query = torch.cat([query_1, query_2], dim=-1)
         key = torch.cat([key_1, key_2], dim=-1)

-        hidden_states = attention(query, key, value)
+        hidden_states = attention(query, key, value, transformer_options=transformer_options)

         if not attn.single:
             hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
@@ -199,6 +200,7 @@ class HiDreamAttention(nn.Module):
         image_tokens_masks: torch.FloatTensor = None,
         norm_text_tokens: torch.FloatTensor = None,
         rope: torch.FloatTensor = None,
+        transformer_options={},
     ) -> torch.Tensor:
         return self.processor(
             self,
@@ -206,6 +208,7 @@ class HiDreamAttention(nn.Module):
             image_tokens_masks = image_tokens_masks,
             text_tokens = norm_text_tokens,
             rope = rope,
+            transformer_options=transformer_options,
         )


@@ -406,7 +409,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
         text_tokens: Optional[torch.FloatTensor] = None,
         adaln_input: Optional[torch.FloatTensor] = None,
         rope: torch.FloatTensor = None,
+        transformer_options={},
     ) -> torch.FloatTensor:
         wtype = image_tokens.dtype
         shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
@@ -419,6 +422,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
             norm_image_tokens,
             image_tokens_masks,
             rope = rope,
+            transformer_options=transformer_options,
         )
         image_tokens = gate_msa_i * attn_output_i + image_tokens

@@ -483,6 +487,7 @@ class HiDreamImageTransformerBlock(nn.Module):
         text_tokens: Optional[torch.FloatTensor] = None,
         adaln_input: Optional[torch.FloatTensor] = None,
         rope: torch.FloatTensor = None,
+        transformer_options={},
     ) -> torch.FloatTensor:
         wtype = image_tokens.dtype
         shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
@@ -500,6 +505,7 @@ class HiDreamImageTransformerBlock(nn.Module):
             image_tokens_masks,
             norm_text_tokens,
             rope = rope,
+            transformer_options=transformer_options,
         )

         image_tokens = gate_msa_i * attn_output_i + image_tokens
@@ -550,6 +556,7 @@ class HiDreamImageBlock(nn.Module):
         text_tokens: Optional[torch.FloatTensor] = None,
         adaln_input: torch.FloatTensor = None,
         rope: torch.FloatTensor = None,
+        transformer_options={},
     ) -> torch.FloatTensor:
         return self.block(
             image_tokens,
@@ -557,6 +564,7 @@ class HiDreamImageBlock(nn.Module):
             text_tokens,
             adaln_input,
             rope,
+            transformer_options=transformer_options,
         )


@@ -786,6 +794,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
                 text_tokens = cur_encoder_hidden_states,
                 adaln_input = adaln_input,
                 rope = rope,
+                transformer_options=transformer_options,
             )
             initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
             block_id += 1
@@ -809,6 +818,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
                 text_tokens=None,
                 adaln_input=adaln_input,
                 rope=rope,
+                transformer_options=transformer_options,
             )
             hidden_states = hidden_states[:, :hidden_states_seq_len]
             block_id += 1
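For context on why the dict must reach the innermost call, here is a self-contained toy of the dispatch pattern. The names default_attention and attention_with_options, the override signature, and the fallback behavior are all assumptions for illustration; ComfyUI's actual optimized_attention implementation may differ.

import torch
import torch.nn.functional as F

def default_attention(q, k, v, heads):
    # q, k, v: (batch, seq, heads * head_dim); split heads, run SDPA, merge back.
    b, s, _ = q.shape
    q, k, v = (t.view(b, -1, heads, t.shape[-1] // heads).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2).reshape(b, s, -1)

def attention_with_options(q, k, v, heads, transformer_options={}):
    # If an override was supplied, let it wrap or replace the default path.
    override = transformer_options.get("optimized_attention_override")
    if override is not None:
        return override(default_attention, q, k, v, heads)
    return default_attention(q, k, v, heads)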