mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-14 13:35:05 +00:00
Made Mochi work with optimized_attention_override
This commit is contained in:
@@ -109,6 +109,7 @@ class AsymmetricAttention(nn.Module):
|
|||||||
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
|
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
|
||||||
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
|
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
|
||||||
crop_y,
|
crop_y,
|
||||||
|
transformer_options={},
|
||||||
**rope_rotation,
|
**rope_rotation,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
rope_cos = rope_rotation.get("rope_cos")
|
rope_cos = rope_rotation.get("rope_cos")
|
||||||
@@ -143,7 +144,7 @@ class AsymmetricAttention(nn.Module):
|
|||||||
|
|
||||||
xy = optimized_attention(q,
|
xy = optimized_attention(q,
|
||||||
k,
|
k,
|
||||||
v, self.num_heads, skip_reshape=True)
|
v, self.num_heads, skip_reshape=True, transformer_options=transformer_options)
|
||||||
|
|
||||||
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
|
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
|
||||||
x = self.proj_x(x)
|
x = self.proj_x(x)
|
||||||
@@ -224,6 +225,7 @@ class AsymmetricJointBlock(nn.Module):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
c: torch.Tensor,
|
c: torch.Tensor,
|
||||||
y: torch.Tensor,
|
y: torch.Tensor,
|
||||||
|
transformer_options={},
|
||||||
**attn_kwargs,
|
**attn_kwargs,
|
||||||
):
|
):
|
||||||
"""Forward pass of a block.
|
"""Forward pass of a block.
|
||||||
@@ -256,6 +258,7 @@ class AsymmetricJointBlock(nn.Module):
|
|||||||
y,
|
y,
|
||||||
scale_x=scale_msa_x,
|
scale_x=scale_msa_x,
|
||||||
scale_y=scale_msa_y,
|
scale_y=scale_msa_y,
|
||||||
|
transformer_options=transformer_options,
|
||||||
**attn_kwargs,
|
**attn_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -524,10 +527,11 @@ class AsymmDiTJoint(nn.Module):
|
|||||||
args["txt"],
|
args["txt"],
|
||||||
rope_cos=args["rope_cos"],
|
rope_cos=args["rope_cos"],
|
||||||
rope_sin=args["rope_sin"],
|
rope_sin=args["rope_sin"],
|
||||||
crop_y=args["num_tokens"]
|
crop_y=args["num_tokens"],
|
||||||
|
transformer_options=args["transformer_options"]
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||||
y_feat = out["txt"]
|
y_feat = out["txt"]
|
||||||
x = out["img"]
|
x = out["img"]
|
||||||
else:
|
else:
|
||||||
@@ -538,6 +542,7 @@ class AsymmDiTJoint(nn.Module):
|
|||||||
rope_cos=rope_cos,
|
rope_cos=rope_cos,
|
||||||
rope_sin=rope_sin,
|
rope_sin=rope_sin,
|
||||||
crop_y=num_tokens,
|
crop_y=num_tokens,
|
||||||
|
transformer_options=transformer_options,
|
||||||
) # (B, M, D), (B, L, D)
|
) # (B, M, D), (B, L, D)
|
||||||
del y_feat # Final layers don't use dense text features.
|
del y_feat # Final layers don't use dense text features.
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user