Made Omnigen 2 work with optimized_attention_override
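The change itself is mechanical: every forward() in the OmniGen 2 stack gains a transformer_options={} keyword and threads it down until it reaches optimized_attention_masked, which can then honor an attention override supplied for that run. The short sketch below illustrates the same plumbing pattern in isolation; the toy module names, the "optimized_attention_override" lookup, and the override signature are assumptions made for illustration (only the key name is taken from the commit title), not ComfyUI's actual implementation. The diff that follows is the real change.

# Minimal sketch of the plumbing pattern, under the assumptions stated above.
import torch
import torch.nn as nn
import torch.nn.functional as F

def attention_maybe_overridden(q, k, v, heads, mask=None, transformer_options={}):
    # If the per-run options dict carries an override, call it instead of the default path.
    override = transformer_options.get("optimized_attention_override")
    if override is not None:
        return override(q, k, v, heads, mask)  # assumed override signature
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

class ToyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.qkv = nn.Linear(dim, dim * 3)

    def forward(self, x, transformer_options={}):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        return attention_maybe_overridden(q, k, v, heads=1, transformer_options=transformer_options)

class ToyModel(nn.Module):
    def __init__(self, dim=16, depth=2):
        super().__init__()
        # Stand-in for noise_refiner / context_refiner / layers in the diff below.
        self.blocks = nn.ModuleList(ToyBlock(dim) for _ in range(depth))

    def forward(self, x, transformer_options={}):
        # The top-level forward receives the dict and forwards it to every block.
        for block in self.blocks:
            x = block(x, transformer_options=transformer_options)
        return x

x = torch.randn(1, 8, 16)
default_out = ToyModel()(x)  # default attention path
custom_out = ToyModel()(x, transformer_options={
    "optimized_attention_override": lambda q, k, v, heads, mask: F.scaled_dot_product_attention(q, k, v, attn_mask=mask),
})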
@@ -120,7 +120,7 @@ class Attention(nn.Module):
             nn.Dropout(0.0)
         )
 
-    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
         batch_size, sequence_length, _ = hidden_states.shape
 
         query = self.to_q(hidden_states)
@@ -146,7 +146,7 @@ class Attention(nn.Module):
         key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
         value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
 
-        hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
+        hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)
         hidden_states = self.to_out[0](hidden_states)
         return hidden_states
 
@@ -182,16 +182,16 @@ class OmniGen2TransformerBlock(nn.Module):
         self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
         self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
 
-    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
         if self.modulation:
             norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
-            attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
+            attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
             hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
             mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
             hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
         else:
             norm_hidden_states = self.norm1(hidden_states)
-            attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
+            attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
             hidden_states = hidden_states + self.norm2(attn_output)
             mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
             hidden_states = hidden_states + self.ffn_norm2(mlp_output)
@@ -390,7 +390,7 @@ class OmniGen2Transformer2DModel(nn.Module):
             ref_img_sizes, img_sizes,
         )
 
-    def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
+    def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, transformer_options={}):
         batch_size = len(hidden_states)
 
         hidden_states = self.x_embedder(hidden_states)
@@ -405,17 +405,17 @@ class OmniGen2Transformer2DModel(nn.Module):
             shift += ref_img_len
 
         for layer in self.noise_refiner:
-            hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
+            hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb, transformer_options=transformer_options)
 
         if ref_image_hidden_states is not None:
             for layer in self.ref_image_refiner:
-                ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)
+                ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb, transformer_options=transformer_options)
 
             hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)
 
         return hidden_states
 
-    def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
+    def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs):
         B, C, H, W = x.shape
         hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
         _, _, H_padded, W_padded = hidden_states.shape
@@ -444,7 +444,7 @@ class OmniGen2Transformer2DModel(nn.Module):
         )
 
         for layer in self.context_refiner:
-            text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
+            text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options)
 
         img_len = hidden_states.shape[1]
         combined_img_hidden_states = self.img_patch_embed_and_refine(
@@ -453,13 +453,14 @@ class OmniGen2Transformer2DModel(nn.Module):
             noise_rotary_emb, ref_img_rotary_emb,
             l_effective_ref_img_len, l_effective_img_len,
             temb,
+            transformer_options=transformer_options,
         )
 
         hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
         attention_mask = None
 
         for layer in self.layers:
-            hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
+            hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb, transformer_options=transformer_options)
 
         hidden_states = self.norm_out(hidden_states, temb)
 
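On the caller side, the options dict is not created in this file; in ComfyUI it normally travels in model_options["transformer_options"] on a ModelPatcher and is handed to the diffusion model each step. Below is a hedged sketch of registering an override there: the key name comes from the commit title, while the callable signature and the exact wiring are assumptions not shown in this diff.

# Hedged caller-side sketch, assuming the override is delivered through
# model_options["transformer_options"] on a cloned ModelPatcher.
import torch.nn.functional as F

def my_attention_override(q, k, v, heads, mask):
    # Route to any custom kernel here; SDPA is used as a stand-in.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

patched = model_patcher.clone()  # model_patcher: a ComfyUI ModelPatcher obtained from a loader node
patched.model_options.setdefault("transformer_options", {})["optimized_attention_override"] = my_attention_override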