mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-13 04:55:53 +00:00
Made SD3 work with optimized_attention_override
This commit is contained in:
@@ -606,7 +606,7 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
|
|||||||
return _block_mixing(*args, **kwargs)
|
return _block_mixing(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _block_mixing(context, x, context_block, x_block, c):
|
def _block_mixing(context, x, context_block, x_block, c, transformer_options={}):
|
||||||
context_qkv, context_intermediates = context_block.pre_attention(context, c)
|
context_qkv, context_intermediates = context_block.pre_attention(context, c)
|
||||||
|
|
||||||
if x_block.x_block_self_attn:
|
if x_block.x_block_self_attn:
|
||||||
@@ -622,6 +622,7 @@ def _block_mixing(context, x, context_block, x_block, c):
|
|||||||
attn = optimized_attention(
|
attn = optimized_attention(
|
||||||
qkv[0], qkv[1], qkv[2],
|
qkv[0], qkv[1], qkv[2],
|
||||||
heads=x_block.attn.num_heads,
|
heads=x_block.attn.num_heads,
|
||||||
|
transformer_options=transformer_options,
|
||||||
)
|
)
|
||||||
context_attn, x_attn = (
|
context_attn, x_attn = (
|
||||||
attn[:, : context_qkv[0].shape[1]],
|
attn[:, : context_qkv[0].shape[1]],
|
||||||
@@ -637,6 +638,7 @@ def _block_mixing(context, x, context_block, x_block, c):
|
|||||||
attn2 = optimized_attention(
|
attn2 = optimized_attention(
|
||||||
x_qkv2[0], x_qkv2[1], x_qkv2[2],
|
x_qkv2[0], x_qkv2[1], x_qkv2[2],
|
||||||
heads=x_block.attn2.num_heads,
|
heads=x_block.attn2.num_heads,
|
||||||
|
transformer_options=transformer_options,
|
||||||
)
|
)
|
||||||
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
|
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
|
||||||
else:
|
else:
|
||||||
@@ -958,10 +960,10 @@ class MMDiT(nn.Module):
|
|||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
|
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"], transformer_options=args["transformer_options"])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||||
context = out["txt"]
|
context = out["txt"]
|
||||||
x = out["img"]
|
x = out["img"]
|
||||||
else:
|
else:
|
||||||
@@ -970,6 +972,7 @@ class MMDiT(nn.Module):
|
|||||||
x,
|
x,
|
||||||
c=c_mod,
|
c=c_mod,
|
||||||
use_checkpoint=self.use_checkpoint,
|
use_checkpoint=self.use_checkpoint,
|
||||||
|
transformer_options=transformer_options,
|
||||||
)
|
)
|
||||||
if control is not None:
|
if control is not None:
|
||||||
control_o = control.get("output")
|
control_o = control.get("output")
|
||||||
|
Reference in New Issue
Block a user