Merge branch 'attention-select' into sortblock

Jedrzej Kosinski
2025-08-30 20:04:48 -07:00
26 changed files with 316 additions and 179 deletions
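The change threads a transformer_options dict from each model's forward **kwargs through every transformer block and attention helper down to optimized_attention, so per-call attention options are visible at the innermost attention call. A minimal sketch of the pattern (ExampleBlock/ExampleModel are illustrative names, not modules from this repository):

    import torch
    import torch.nn as nn
    from comfy.ldm.modules.attention import optimized_attention

    class ExampleBlock(nn.Module):
        # hypothetical block following the pattern applied throughout this diff
        def __init__(self, dim, heads):
            super().__init__()
            self.heads = heads
            self.qkv = nn.Linear(dim, dim * 3)

        def forward(self, x, transformer_options={}):
            # the dict is forwarded untouched so the attention call can see it
            q, k, v = self.qkv(x).chunk(3, dim=-1)
            return optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)

    class ExampleModel(nn.Module):
        # hypothetical top-level model; mirrors how the patched models pull the dict out of kwargs
        def __init__(self, dim=64, heads=8):
            super().__init__()
            self.block = ExampleBlock(dim, heads)

        def forward(self, x, **kwargs):
            transformer_options = kwargs.get("transformer_options", {})
            return self.block(x, transformer_options=transformer_options)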

View File

@@ -133,6 +133,7 @@ class Attention(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
+transformer_options={},
**cross_attention_kwargs,
) -> torch.Tensor:
return self.processor(
@@ -140,6 +141,7 @@ class Attention(nn.Module):
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
+transformer_options=transformer_options,
**cross_attention_kwargs,
)
@@ -366,6 +368,7 @@ class CustomerAttnProcessor2_0:
encoder_attention_mask: Optional[torch.FloatTensor] = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
+transformer_options={},
*args,
**kwargs,
) -> torch.Tensor:
@@ -433,7 +436,7 @@ class CustomerAttnProcessor2_0:
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = optimized_attention(
-query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
+query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options,
).to(query.dtype)
# linear proj
@@ -697,6 +700,7 @@ class LinearTransformerBlock(nn.Module):
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
temb: torch.FloatTensor = None,
+transformer_options={},
):
N = hidden_states.shape[0]
@@ -720,6 +724,7 @@ class LinearTransformerBlock(nn.Module):
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
+transformer_options=transformer_options,
)
else:
attn_output, _ = self.attn(
@@ -729,6 +734,7 @@ class LinearTransformerBlock(nn.Module):
encoder_attention_mask=None,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=None,
+transformer_options=transformer_options,
)
if self.use_adaln_single:
@@ -743,6 +749,7 @@ class LinearTransformerBlock(nn.Module):
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
+transformer_options=transformer_options,
)
hidden_states = attn_output + hidden_states

View File

@@ -314,6 +314,7 @@ class ACEStepTransformer2DModel(nn.Module):
output_length: int = 0,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
+transformer_options={},
):
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
temb = self.t_block(embedded_timestep)
@@ -339,6 +340,7 @@ class ACEStepTransformer2DModel(nn.Module):
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
temb=temb,
+transformer_options=transformer_options,
)
output = self.final_layer(hidden_states, embedded_timestep, output_length)
@@ -393,6 +395,7 @@ class ACEStepTransformer2DModel(nn.Module):
output_length = hidden_states.shape[-1]
+transformer_options = kwargs.get("transformer_options", {})
output = self.decode(
hidden_states=hidden_states,
attention_mask=attention_mask,
@@ -402,6 +405,7 @@ class ACEStepTransformer2DModel(nn.Module):
output_length=output_length,
block_controlnet_hidden_states=block_controlnet_hidden_states,
controlnet_scale=controlnet_scale,
+transformer_options=transformer_options,
)
return output

View File

@@ -298,7 +298,8 @@ class Attention(nn.Module):
mask = None,
context_mask = None,
rotary_pos_emb = None,
-causal = None
+causal = None,
+transformer_options={},
):
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
@@ -363,7 +364,7 @@ class Attention(nn.Module):
heads_per_kv_head = h // kv_h
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
-out = optimized_attention(q, k, v, h, skip_reshape=True)
+out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
out = self.to_out(out)
if mask is not None:
@@ -488,7 +489,8 @@ class TransformerBlock(nn.Module):
global_cond=None,
mask = None,
context_mask = None,
-rotary_pos_emb = None
+rotary_pos_emb = None,
+transformer_options={}
):
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
@@ -498,12 +500,12 @@ class TransformerBlock(nn.Module):
residual = x
x = self.pre_norm(x)
x = x * (1 + scale_self) + shift_self
-x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
+x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
x = x * torch.sigmoid(1 - gate_self)
x = x + residual
if context is not None:
-x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
+x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
if self.conformer is not None:
x = x + self.conformer(x)
@@ -517,10 +519,10 @@ class TransformerBlock(nn.Module):
x = x + residual
else:
-x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
+x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
if context is not None:
-x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
+x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
if self.conformer is not None:
x = x + self.conformer(x)
@@ -606,7 +608,8 @@ class ContinuousTransformer(nn.Module):
return_info = False,
**kwargs
):
-patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
+transformer_options = kwargs.get("transformer_options", {})
+patches_replace = transformer_options.get("patches_replace", {})
batch, seq, device = *x.shape[:2], x.device
context = kwargs["context"]
@@ -645,13 +648,13 @@ class ContinuousTransformer(nn.Module):
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"]) out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"])
return out return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap}) out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"] x = out["img"]
else: else:
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context) x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options)
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
if return_info: if return_info:
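The block-replace path above now forwards the dict in the args it hands to a registered patch, so a patch can read args["transformer_options"] next to the usual tensors. A hedged sketch of such a callback (the body is illustrative; a real patch would do more than just call the original block):

    def my_double_block_patch(args, extra):
        # hypothetical patch registered under ("double_block", i) in patches_replace
        transformer_options = args.get("transformer_options", {})
        original_block = extra["original_block"]
        # inspect or adjust transformer_options here if needed, then run the wrapped block
        out = original_block(args)
        return out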

View File

@@ -85,7 +85,7 @@ class SingleAttention(nn.Module):
)
#@torch.compile()
-def forward(self, c):
+def forward(self, c, transformer_options={}):
bsz, seqlen1, _ = c.shape
@@ -95,7 +95,7 @@ class SingleAttention(nn.Module):
v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
q, k = self.q_norm1(q), self.k_norm1(k)
-output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
+output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
c = self.w1o(output)
return c
@@ -144,7 +144,7 @@ class DoubleAttention(nn.Module):
#@torch.compile()
-def forward(self, c, x):
+def forward(self, c, x, transformer_options={}):
bsz, seqlen1, _ = c.shape
bsz, seqlen2, _ = x.shape
@@ -168,7 +168,7 @@ class DoubleAttention(nn.Module):
torch.cat([cv, xv], dim=1),
)
-output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
+output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
c, x = output.split([seqlen1, seqlen2], dim=1)
c = self.w1o(c)
@@ -207,7 +207,7 @@ class MMDiTBlock(nn.Module):
self.is_last = is_last
#@torch.compile()
-def forward(self, c, x, global_cond, **kwargs):
+def forward(self, c, x, global_cond, transformer_options={}, **kwargs):
cres, xres = c, x
@@ -225,7 +225,7 @@ class MMDiTBlock(nn.Module):
x = modulate(self.normX1(x), xshift_msa, xscale_msa)
# attention
-c, x = self.attn(c, x)
+c, x = self.attn(c, x, transformer_options=transformer_options)
c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
@@ -255,13 +255,13 @@ class DiTBlock(nn.Module):
self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
#@torch.compile()
-def forward(self, cx, global_cond, **kwargs):
+def forward(self, cx, global_cond, transformer_options={}, **kwargs):
cxres = cx
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
global_cond
).chunk(6, dim=1)
cx = modulate(self.norm1(cx), shift_msa, scale_msa)
-cx = self.attn(cx)
+cx = self.attn(cx, transformer_options=transformer_options)
cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
cx = gate_mlp.unsqueeze(1) * mlpout
@@ -473,13 +473,14 @@ class MMDiT(nn.Module):
out = {}
out["txt"], out["img"] = layer(args["txt"],
args["img"],
-args["vec"])
+args["vec"],
+transformer_options=args["transformer_options"])
return out
-out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
+out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
c = out["txt"]
x = out["img"]
else:
-c, x = layer(c, x, global_cond, **kwargs)
+c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs)
if len(self.single_layers) > 0:
c_len = c.size(1)
@@ -488,13 +489,13 @@ class MMDiT(nn.Module):
if ("single_block", i) in blocks_replace: if ("single_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["img"] = layer(args["img"], args["vec"]) out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"])
return out return out
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap}) out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
cx = out["img"] cx = out["img"]
else: else:
cx = layer(cx, global_cond, **kwargs) cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs)
x = cx[:, c_len:] x = cx[:, c_len:]

View File

@@ -32,12 +32,12 @@ class OptimizedAttention(nn.Module):
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
-def forward(self, q, k, v):
+def forward(self, q, k, v, transformer_options={}):
q = self.to_q(q)
k = self.to_k(k)
v = self.to_v(v)
-out = optimized_attention(q, k, v, self.heads)
+out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)
return self.out_proj(out)
@@ -47,13 +47,13 @@ class Attention2D(nn.Module):
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
-def forward(self, x, kv, self_attn=False):
+def forward(self, x, kv, self_attn=False, transformer_options={}):
orig_shape = x.shape
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
if self_attn:
kv = torch.cat([x, kv], dim=1)
# x = self.attn(x, kv, kv, need_weights=False)[0]
-x = self.attn(x, kv, kv)
+x = self.attn(x, kv, kv, transformer_options=transformer_options)
x = x.permute(0, 2, 1).view(*orig_shape)
return x
@@ -114,9 +114,9 @@ class AttnBlock(nn.Module):
operations.Linear(c_cond, c, dtype=dtype, device=device)
)
-def forward(self, x, kv):
+def forward(self, x, kv, transformer_options={}):
kv = self.kv_mapper(kv)
-x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
+x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options)
return x

View File

@@ -173,7 +173,7 @@ class StageB(nn.Module):
clip = self.clip_norm(clip)
return clip
-def _down_encode(self, x, r_embed, clip):
+def _down_encode(self, x, r_embed, clip, transformer_options={}):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
@@ -187,7 +187,7 @@ class StageB(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
-x = block(x, clip)
+x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -199,7 +199,7 @@ class StageB(nn.Module):
level_outputs.insert(0, x)
return level_outputs
-def _up_decode(self, level_outputs, r_embed, clip):
+def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
@@ -216,7 +216,7 @@ class StageB(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
-x = block(x, clip)
+x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -228,7 +228,7 @@ class StageB(nn.Module):
x = upscaler(x)
return x
-def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
+def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs):
if pixels is None:
pixels = x.new_zeros(x.size(0), 3, 8, 8)
@@ -245,8 +245,8 @@ class StageB(nn.Module):
nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
align_corners=True)
-level_outputs = self._down_encode(x, r_embed, clip)
+level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options)
-x = self._up_decode(level_outputs, r_embed, clip)
+x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options)
return self.clf(x)
def update_weights_ema(self, src_model, beta=0.999):

View File

@@ -182,7 +182,7 @@ class StageC(nn.Module):
clip = self.clip_norm(clip)
return clip
-def _down_encode(self, x, r_embed, clip, cnet=None):
+def _down_encode(self, x, r_embed, clip, cnet=None, transformer_options={}):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
@@ -201,7 +201,7 @@ class StageC(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
-x = block(x, clip)
+x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -213,7 +213,7 @@ class StageC(nn.Module):
level_outputs.insert(0, x)
return level_outputs
-def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
+def _up_decode(self, level_outputs, r_embed, clip, cnet=None, transformer_options={}):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
@@ -235,7 +235,7 @@ class StageC(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
-x = block(x, clip)
+x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@@ -247,7 +247,7 @@ class StageC(nn.Module):
x = upscaler(x)
return x
-def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
+def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, transformer_options={}, **kwargs):
# Process the conditioning embeddings
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
for c in self.t_conds:
@@ -262,8 +262,8 @@ class StageC(nn.Module):
# Model Blocks
x = self.embedding(x)
-level_outputs = self._down_encode(x, r_embed, clip, cnet)
+level_outputs = self._down_encode(x, r_embed, clip, cnet, transformer_options=transformer_options)
-x = self._up_decode(level_outputs, r_embed, clip, cnet)
+x = self._up_decode(level_outputs, r_embed, clip, cnet, transformer_options=transformer_options)
return self.clf(x)
def update_weights_ema(self, src_model, beta=0.999):

View File

@@ -76,7 +76,7 @@ class DoubleStreamBlock(nn.Module):
)
self.flipped_img_txt = flipped_img_txt
-def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
+def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
@@ -95,7 +95,7 @@ class DoubleStreamBlock(nn.Module):
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
-pe=pe, mask=attn_mask)
+pe=pe, mask=attn_mask, transformer_options=transformer_options)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
@@ -148,7 +148,7 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh") self.mlp_act = nn.GELU(approximate="tanh")
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor: def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
mod = vec mod = vec
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x)) x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@@ -157,7 +157,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k, v)
# compute attention
-attn = attention(q, k, v, pe=pe, mask=attn_mask)
+attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x.addcmul_(mod.gate, output)

View File

@@ -193,14 +193,16 @@ class Chroma(nn.Module):
txt=args["txt"], txt=args["txt"],
vec=args["vec"], vec=args["vec"],
pe=args["pe"], pe=args["pe"],
attn_mask=args.get("attn_mask")) attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out return out
out = blocks_replace[("double_block", i)]({"img": img, out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt, "txt": txt,
"vec": double_mod, "vec": double_mod,
"pe": pe, "pe": pe,
"attn_mask": attn_mask}, "attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap}) {"original_block": block_wrap})
txt = out["txt"] txt = out["txt"]
img = out["img"] img = out["img"]
@@ -209,7 +211,8 @@ class Chroma(nn.Module):
txt=txt,
vec=double_mod,
pe=pe,
-attn_mask=attn_mask)
+attn_mask=attn_mask,
+transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
@@ -229,17 +232,19 @@ class Chroma(nn.Module):
out["img"] = block(args["img"], out["img"] = block(args["img"],
vec=args["vec"], vec=args["vec"],
pe=args["pe"], pe=args["pe"],
attn_mask=args.get("attn_mask")) attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out return out
out = blocks_replace[("single_block", i)]({"img": img, out = blocks_replace[("single_block", i)]({"img": img,
"vec": single_mod, "vec": single_mod,
"pe": pe, "pe": pe,
"attn_mask": attn_mask}, "attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap}) {"original_block": block_wrap})
img = out["img"] img = out["img"]
else: else:
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask) img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
if control is not None: # Controlnet if control is not None: # Controlnet
control_o = control.get("output") control_o = control.get("output")

View File

@@ -176,6 +176,7 @@ class Attention(nn.Module):
context=None,
mask=None,
rope_emb=None,
+transformer_options={},
**kwargs,
):
"""
@@ -184,7 +185,7 @@ class Attention(nn.Module):
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
-out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
+out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True, transformer_options=transformer_options)
del q, k, v
out = rearrange(out, " b n s c -> s b (n c)")
return self.to_out(out)
@@ -546,6 +547,7 @@ class VideoAttn(nn.Module):
context: Optional[torch.Tensor] = None,
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
+transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Forward pass for video attention.
@@ -571,6 +573,7 @@ class VideoAttn(nn.Module):
context_M_B_D,
crossattn_mask,
rope_emb=rope_emb_L_1_1_D,
+transformer_options=transformer_options,
)
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
return x_T_H_W_B_D
@@ -665,6 +668,7 @@ class DITBuildingBlock(nn.Module):
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
+transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Forward pass for dynamically configured blocks with adaptive normalization.
@@ -702,6 +706,7 @@ class DITBuildingBlock(nn.Module):
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
context=None,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
+transformer_options=transformer_options,
)
elif self.block_type in ["cross_attn", "ca"]:
x = x + gate_1_1_1_B_D * self.block(
@@ -709,6 +714,7 @@ class DITBuildingBlock(nn.Module):
context=crossattn_emb,
crossattn_mask=crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
+transformer_options=transformer_options,
)
else:
raise ValueError(f"Unknown block type: {self.block_type}")
@@ -784,6 +790,7 @@ class GeneralDITTransformerBlock(nn.Module):
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
+transformer_options: Optional[dict] = {},
) -> torch.Tensor:
for block in self.blocks:
x = block(
@@ -793,5 +800,6 @@ class GeneralDITTransformerBlock(nn.Module):
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
+transformer_options=transformer_options,
)
return x

View File

@@ -520,6 +520,7 @@ class GeneralDIT(nn.Module):
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
+transformer_options = kwargs.get("transformer_options", {})
for _, block in self.blocks.items():
assert (
self.blocks["block0"].x_format == block.x_format
@@ -534,6 +535,7 @@ class GeneralDIT(nn.Module):
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
+transformer_options=transformer_options,
)
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
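As in the hunk above, each entry point retrieves the dict with kwargs.get("transformer_options", {}), so call sites that predate this change and pass nothing still work. A trivial illustration with a hypothetical function:

    def forward(x, **kwargs):
        transformer_options = kwargs.get("transformer_options", {})  # falls back to an empty dict
        return transformer_options

    forward(0)                                 # {} -> legacy callers keep working
    forward(0, transformer_options={"a": 1})   # {"a": 1} -> new callers pass options through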

View File

@@ -44,7 +44,7 @@ class GPT2FeedForward(nn.Module):
return x
-def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
+def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
"""Computes multi-head attention using PyTorch's native implementation.
This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
@@ -71,7 +71,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
-return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
+return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, transformer_options=transformer_options)
class Attention(nn.Module):
@@ -180,8 +180,8 @@ class Attention(nn.Module):
return q, k, v
-def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
-result = self.attn_op(q, k, v) # [B, S, H, D]
+result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D]
return self.output_dropout(self.output_proj(result))
def forward(
@@ -189,6 +189,7 @@ class Attention(nn.Module):
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
+transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Args:
@@ -196,7 +197,7 @@ class Attention(nn.Module):
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
-return self.compute_attention(q, k, v)
+return self.compute_attention(q, k, v, transformer_options=transformer_options)
class Timesteps(nn.Module):
@@ -459,6 +460,7 @@ class Block(nn.Module):
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
+transformer_options: Optional[dict] = {},
) -> torch.Tensor:
if extra_per_block_pos_emb is not None:
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
@@ -512,6 +514,7 @@ class Block(nn.Module):
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
None,
rope_emb=rope_emb_L_1_1_D,
+transformer_options=transformer_options,
),
"b (t h w) d -> b t h w d",
t=T,
@@ -525,6 +528,7 @@ class Block(nn.Module):
layer_norm_cross_attn: Callable,
_scale_cross_attn_B_T_1_1_D: torch.Tensor,
_shift_cross_attn_B_T_1_1_D: torch.Tensor,
+transformer_options: Optional[dict] = {},
) -> torch.Tensor:
_normalized_x_B_T_H_W_D = _fn(
_x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
@@ -534,6 +538,7 @@ class Block(nn.Module):
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
+transformer_options=transformer_options,
),
"b (t h w) d -> b t h w d",
t=T,
@@ -547,6 +552,7 @@ class Block(nn.Module):
self.layer_norm_cross_attn,
scale_cross_attn_B_T_1_1_D,
shift_cross_attn_B_T_1_1_D,
+transformer_options=transformer_options,
)
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
@@ -865,6 +871,7 @@ class MiniTrainDIT(nn.Module):
"rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0), "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
"adaln_lora_B_T_3D": adaln_lora_B_T_3D, "adaln_lora_B_T_3D": adaln_lora_B_T_3D,
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
"transformer_options": kwargs.get("transformer_options", {}),
} }
for block in self.blocks: for block in self.blocks:
x_B_T_H_W_D = block( x_B_T_H_W_D = block(

View File

@@ -159,7 +159,7 @@ class DoubleStreamBlock(nn.Module):
)
self.flipped_img_txt = flipped_img_txt
-def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
+def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
@@ -182,7 +182,7 @@ class DoubleStreamBlock(nn.Module):
attn = attention(torch.cat((img_q, txt_q), dim=2),
torch.cat((img_k, txt_k), dim=2),
torch.cat((img_v, txt_v), dim=2),
-pe=pe, mask=attn_mask)
+pe=pe, mask=attn_mask, transformer_options=transformer_options)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
@@ -190,7 +190,7 @@ class DoubleStreamBlock(nn.Module):
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
-pe=pe, mask=attn_mask)
+pe=pe, mask=attn_mask, transformer_options=transformer_options)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
@@ -244,7 +244,7 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh") self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor: def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
mod, _ = self.modulation(vec) mod, _ = self.modulation(vec)
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@@ -252,7 +252,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k, v)
# compute attention
-attn = attention(q, k, v, pe=pe, mask=attn_mask)
+attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += apply_mod(output, mod.gate, None, modulation_dims)

View File

@@ -6,7 +6,7 @@ from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
-def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
+def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
q_shape = q.shape
k_shape = k.shape
@@ -17,7 +17,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
heads = q.shape[1]
-x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
+x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
return x

View File

@@ -135,14 +135,16 @@ class Flux(nn.Module):
txt=args["txt"], txt=args["txt"],
vec=args["vec"], vec=args["vec"],
pe=args["pe"], pe=args["pe"],
attn_mask=args.get("attn_mask")) attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out return out
out = blocks_replace[("double_block", i)]({"img": img, out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt, "txt": txt,
"vec": vec, "vec": vec,
"pe": pe, "pe": pe,
"attn_mask": attn_mask}, "attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap}) {"original_block": block_wrap})
txt = out["txt"] txt = out["txt"]
img = out["img"] img = out["img"]
@@ -151,7 +153,8 @@ class Flux(nn.Module):
txt=txt,
vec=vec,
pe=pe,
-attn_mask=attn_mask)
+attn_mask=attn_mask,
+transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
@@ -172,17 +175,19 @@ class Flux(nn.Module):
out["img"] = block(args["img"], out["img"] = block(args["img"],
vec=args["vec"], vec=args["vec"],
pe=args["pe"], pe=args["pe"],
attn_mask=args.get("attn_mask")) attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out return out
out = blocks_replace[("single_block", i)]({"img": img, out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec, "vec": vec,
"pe": pe, "pe": pe,
"attn_mask": attn_mask}, "attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap}) {"original_block": block_wrap})
img = out["img"] img = out["img"]
else: else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
if control is not None: # Controlnet if control is not None: # Controlnet
control_o = control.get("output") control_o = control.get("output")

View File

@@ -109,6 +109,7 @@ class AsymmetricAttention(nn.Module):
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
crop_y,
+transformer_options={},
**rope_rotation,
) -> Tuple[torch.Tensor, torch.Tensor]:
rope_cos = rope_rotation.get("rope_cos")
@@ -143,7 +144,7 @@ class AsymmetricAttention(nn.Module):
xy = optimized_attention(q,
k,
-v, self.num_heads, skip_reshape=True)
+v, self.num_heads, skip_reshape=True, transformer_options=transformer_options)
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
x = self.proj_x(x)
@@ -224,6 +225,7 @@ class AsymmetricJointBlock(nn.Module):
x: torch.Tensor,
c: torch.Tensor,
y: torch.Tensor,
+transformer_options={},
**attn_kwargs,
):
"""Forward pass of a block.
@@ -256,6 +258,7 @@ class AsymmetricJointBlock(nn.Module):
y,
scale_x=scale_msa_x,
scale_y=scale_msa_y,
+transformer_options=transformer_options,
**attn_kwargs,
)
@@ -524,10 +527,11 @@ class AsymmDiTJoint(nn.Module):
args["txt"], args["txt"],
rope_cos=args["rope_cos"], rope_cos=args["rope_cos"],
rope_sin=args["rope_sin"], rope_sin=args["rope_sin"],
crop_y=args["num_tokens"] crop_y=args["num_tokens"],
transformer_options=args["transformer_options"]
) )
return out return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap}) out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens, "transformer_options": transformer_options}, {"original_block": block_wrap})
y_feat = out["txt"] y_feat = out["txt"]
x = out["img"] x = out["img"]
else: else:
@@ -538,6 +542,7 @@ class AsymmDiTJoint(nn.Module):
rope_cos=rope_cos,
rope_sin=rope_sin,
crop_y=num_tokens,
+transformer_options=transformer_options,
) # (B, M, D), (B, L, D)
del y_feat # Final layers don't use dense text features.

View File

@@ -72,8 +72,8 @@ class TimestepEmbed(nn.Module):
return t_emb
-def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
+def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, transformer_options={}):
-return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
+return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2], transformer_options=transformer_options)
class HiDreamAttnProcessor_flashattn:
@@ -86,6 +86,7 @@ class HiDreamAttnProcessor_flashattn:
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
+transformer_options={},
*args,
**kwargs,
) -> torch.FloatTensor:
@@ -133,7 +134,7 @@ class HiDreamAttnProcessor_flashattn:
query = torch.cat([query_1, query_2], dim=-1)
key = torch.cat([key_1, key_2], dim=-1)
-hidden_states = attention(query, key, value)
+hidden_states = attention(query, key, value, transformer_options=transformer_options)
if not attn.single:
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
@@ -199,6 +200,7 @@ class HiDreamAttention(nn.Module):
image_tokens_masks: torch.FloatTensor = None,
norm_text_tokens: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
+transformer_options={},
) -> torch.Tensor:
return self.processor(
self,
@@ -206,6 +208,7 @@ class HiDreamAttention(nn.Module):
image_tokens_masks = image_tokens_masks,
text_tokens = norm_text_tokens,
rope = rope,
+transformer_options=transformer_options,
)
@@ -406,7 +409,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
+transformer_options={},
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
@@ -419,6 +422,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
norm_image_tokens,
image_tokens_masks,
rope = rope,
+transformer_options=transformer_options,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
@@ -483,6 +487,7 @@ class HiDreamImageTransformerBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
+transformer_options={},
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
@@ -500,6 +505,7 @@ class HiDreamImageTransformerBlock(nn.Module):
image_tokens_masks,
norm_text_tokens,
rope = rope,
+transformer_options=transformer_options,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
@@ -550,6 +556,7 @@ class HiDreamImageBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
+transformer_options={},
) -> torch.FloatTensor:
return self.block(
image_tokens,
@@ -557,6 +564,7 @@ class HiDreamImageBlock(nn.Module):
text_tokens,
adaln_input,
rope,
+transformer_options=transformer_options,
)
@@ -786,6 +794,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
text_tokens = cur_encoder_hidden_states,
adaln_input = adaln_input,
rope = rope,
+transformer_options=transformer_options,
)
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
block_id += 1
@@ -809,6 +818,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
text_tokens=None,
adaln_input=adaln_input,
rope=rope,
+transformer_options=transformer_options,
)
hidden_states = hidden_states[:, :hidden_states_seq_len]
block_id += 1

View File

@@ -99,14 +99,16 @@ class Hunyuan3Dv2(nn.Module):
txt=args["txt"], txt=args["txt"],
vec=args["vec"], vec=args["vec"],
pe=args["pe"], pe=args["pe"],
attn_mask=args.get("attn_mask")) attn_mask=args.get("attn_mask"),
transformer_options=args["transformer_options"])
return out return out
out = blocks_replace[("double_block", i)]({"img": img, out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt, "txt": txt,
"vec": vec, "vec": vec,
"pe": pe, "pe": pe,
"attn_mask": attn_mask}, "attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap}) {"original_block": block_wrap})
txt = out["txt"] txt = out["txt"]
img = out["img"] img = out["img"]
@@ -115,7 +117,8 @@ class Hunyuan3Dv2(nn.Module):
txt=txt,
vec=vec,
pe=pe,
-attn_mask=attn_mask)
+attn_mask=attn_mask,
+transformer_options=transformer_options)
img = torch.cat((txt, img), 1)
@@ -126,17 +129,19 @@ class Hunyuan3Dv2(nn.Module):
out["img"] = block(args["img"], out["img"] = block(args["img"],
vec=args["vec"], vec=args["vec"],
pe=args["pe"], pe=args["pe"],
attn_mask=args.get("attn_mask")) attn_mask=args.get("attn_mask"),
transformer_options=args["transformer_options"])
return out return out
out = blocks_replace[("single_block", i)]({"img": img, out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec, "vec": vec,
"pe": pe, "pe": pe,
"attn_mask": attn_mask}, "attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap}) {"original_block": block_wrap})
img = out["img"] img = out["img"]
else: else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
img = img[:, txt.shape[1]:, ...] img = img[:, txt.shape[1]:, ...]
img = self.final_layer(img, vec) img = self.final_layer(img, vec)
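For Hunyuan3Dv2 (and the other DiT models below) the args dict handed to a ("double_block", i) or ("single_block", i) replace patch now also carries "transformer_options", and block_wrap forwards it to the real block. A hedged sketch of what a patch callback can look like under the new convention; my_double_block_patch is a placeholder name, and only the two dict arguments and their keys come from the diff:

def my_double_block_patch(args, extra):
    opts = args["transformer_options"]      # newly forwarded by the model
    # run the block exactly as the model would have (this calls block_wrap,
    # which now passes transformer_options to the original double block)
    out = extra["original_block"](args)
    # a real patch could post-process out["img"] / out["txt"] here, keyed on
    # whatever the caller stored in opts
    return out

How the patch ends up in blocks_replace is unchanged by this commit; only the contents of the args dict grow.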


@@ -78,13 +78,13 @@ class TokenRefinerBlock(nn.Module):
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
) )
def forward(self, x, c, mask): def forward(self, x, c, mask, transformer_options={}):
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1) mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x) norm_x = self.norm1(x)
qkv = self.self_attn.qkv(norm_x) qkv = self.self_attn.qkv(norm_x)
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True) attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True, transformer_options=transformer_options)
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1) x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1) x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
@@ -115,14 +115,14 @@ class IndividualTokenRefiner(nn.Module):
] ]
) )
def forward(self, x, c, mask): def forward(self, x, c, mask, transformer_options={}):
m = None m = None
if mask is not None: if mask is not None:
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1) m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
m = m + m.transpose(2, 3) m = m + m.transpose(2, 3)
for block in self.blocks: for block in self.blocks:
x = block(x, c, m) x = block(x, c, m, transformer_options=transformer_options)
return x return x
@@ -150,6 +150,7 @@ class TokenRefiner(nn.Module):
x, x,
timesteps, timesteps,
mask, mask,
transformer_options={},
): ):
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype)) t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
# m = mask.float().unsqueeze(-1) # m = mask.float().unsqueeze(-1)
@@ -158,7 +159,7 @@ class TokenRefiner(nn.Module):
c = t + self.c_embedder(c.to(x.dtype)) c = t + self.c_embedder(c.to(x.dtype))
x = self.input_embedder(x) x = self.input_embedder(x)
x = self.individual_token_refiner(x, c, mask) x = self.individual_token_refiner(x, c, mask, transformer_options=transformer_options)
return x return x
class HunyuanVideo(nn.Module): class HunyuanVideo(nn.Module):
@@ -267,7 +268,7 @@ class HunyuanVideo(nn.Module):
if txt_mask is not None and not torch.is_floating_point(txt_mask): if txt_mask is not None and not torch.is_floating_point(txt_mask):
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
txt = self.txt_in(txt, timesteps, txt_mask) txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
ids = torch.cat((img_ids, txt_ids), dim=1) ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids) pe = self.pe_embedder(ids)
@@ -285,14 +286,14 @@ class HunyuanVideo(nn.Module):
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"]) out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"], transformer_options=args["transformer_options"])
return out return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap}) out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt, 'transformer_options': transformer_options}, {"original_block": block_wrap})
txt = out["txt"] txt = out["txt"]
img = out["img"] img = out["img"]
else: else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt, transformer_options=transformer_options)
if control is not None: # Controlnet if control is not None: # Controlnet
control_i = control.get("input") control_i = control.get("input")
@@ -307,13 +308,13 @@ class HunyuanVideo(nn.Module):
if ("single_block", i) in blocks_replace: if ("single_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"]) out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"], transformer_options=args["transformer_options"])
return out return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap}) out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims, 'transformer_options': transformer_options}, {"original_block": block_wrap})
img = out["img"] img = out["img"]
else: else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims) img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims, transformer_options=transformer_options)
if control is not None: # Controlnet if control is not None: # Controlnet
control_o = control.get("output") control_o = control.get("output")


@@ -271,7 +271,7 @@ class CrossAttention(nn.Module):
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
def forward(self, x, context=None, mask=None, pe=None): def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
q = self.to_q(x) q = self.to_q(x)
context = x if context is None else context context = x if context is None else context
k = self.to_k(context) k = self.to_k(context)
@@ -285,9 +285,9 @@ class CrossAttention(nn.Module):
k = apply_rotary_emb(k, pe) k = apply_rotary_emb(k, pe)
if mask is None: if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision) out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
else: else:
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision) out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
return self.to_out(out) return self.to_out(out)
@@ -303,12 +303,12 @@ class BasicTransformerBlock(nn.Module):
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None): def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
x += self.attn2(x, context=context, mask=attention_mask) x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
x += self.ff(y) * gate_mlp x += self.ff(y) * gate_mlp
@@ -479,10 +479,10 @@ class LTXVModel(torch.nn.Module):
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"]) out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
return out return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap}) out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"] x = out["img"]
else: else:
x = block( x = block(
@@ -490,7 +490,8 @@ class LTXVModel(torch.nn.Module):
context=context, context=context,
attention_mask=attention_mask, attention_mask=attention_mask,
timestep=timestep, timestep=timestep,
pe=pe pe=pe,
transformer_options=transformer_options,
) )
# 3. Output # 3. Output


@@ -104,6 +104,7 @@ class JointAttention(nn.Module):
x: torch.Tensor, x: torch.Tensor,
x_mask: torch.Tensor, x_mask: torch.Tensor,
freqs_cis: torch.Tensor, freqs_cis: torch.Tensor,
transformer_options={},
) -> torch.Tensor: ) -> torch.Tensor:
""" """
@@ -140,7 +141,7 @@ class JointAttention(nn.Module):
if n_rep >= 1: if n_rep >= 1:
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True) output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options)
return self.out(output) return self.out(output)
@@ -268,6 +269,7 @@ class JointTransformerBlock(nn.Module):
x_mask: torch.Tensor, x_mask: torch.Tensor,
freqs_cis: torch.Tensor, freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor]=None, adaln_input: Optional[torch.Tensor]=None,
transformer_options={},
): ):
""" """
Perform a forward pass through the TransformerBlock. Perform a forward pass through the TransformerBlock.
@@ -290,6 +292,7 @@ class JointTransformerBlock(nn.Module):
modulate(self.attention_norm1(x), scale_msa), modulate(self.attention_norm1(x), scale_msa),
x_mask, x_mask,
freqs_cis, freqs_cis,
transformer_options=transformer_options,
) )
) )
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
@@ -304,6 +307,7 @@ class JointTransformerBlock(nn.Module):
self.attention_norm1(x), self.attention_norm1(x),
x_mask, x_mask,
freqs_cis, freqs_cis,
transformer_options=transformer_options,
) )
) )
x = x + self.ffn_norm2( x = x + self.ffn_norm2(
@@ -494,7 +498,7 @@ class NextDiT(nn.Module):
return imgs return imgs
def patchify_and_embed( def patchify_and_embed(
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
bsz = len(x) bsz = len(x)
pH = pW = self.patch_size pH = pW = self.patch_size
@@ -554,7 +558,7 @@ class NextDiT(nn.Module):
# refine context # refine context
for layer in self.context_refiner: for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
# refine image # refine image
flat_x = [] flat_x = []
@@ -573,7 +577,7 @@ class NextDiT(nn.Module):
padded_img_embed = self.x_embedder(padded_img_embed) padded_img_embed = self.x_embedder(padded_img_embed)
padded_img_mask = padded_img_mask.unsqueeze(1) padded_img_mask = padded_img_mask.unsqueeze(1)
for layer in self.noise_refiner: for layer in self.noise_refiner:
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t) padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
if cap_mask is not None: if cap_mask is not None:
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device) mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
@@ -616,12 +620,13 @@ class NextDiT(nn.Module):
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
transformer_options = kwargs.get("transformer_options", {})
x_is_tensor = isinstance(x, torch.Tensor) x_is_tensor = isinstance(x, torch.Tensor)
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens) x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(x.device) freqs_cis = freqs_cis.to(x.device)
for layer in self.layers: for layer in self.layers:
x = layer(x, mask, freqs_cis, adaln_input) x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
x = self.final_layer(x, adaln_input) x = self.final_layer(x, adaln_input)
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w] x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]


@@ -5,8 +5,9 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn, einsum from torch import nn, einsum
from einops import rearrange, repeat from einops import rearrange, repeat
from typing import Optional from typing import Optional, Any, Callable, Union
import logging import logging
import functools
from .diffusionmodules.util import AlphaBlender, timestep_embedding from .diffusionmodules.util import AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention from .sub_quadratic_attention import efficient_dot_product_attention
@@ -17,23 +18,45 @@ if model_management.xformers_enabled():
import xformers import xformers
import xformers.ops import xformers.ops
if model_management.sage_attention_enabled(): SAGE_ATTENTION_IS_AVAILABLE = False
try: try:
from sageattention import sageattn from sageattention import sageattn
except ModuleNotFoundError as e: SAGE_ATTENTION_IS_AVAILABLE = True
except ModuleNotFoundError as e:
if model_management.sage_attention_enabled():
if e.name == "sageattention": if e.name == "sageattention":
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention") logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
else: else:
raise e raise e
exit(-1) exit(-1)
if model_management.flash_attention_enabled(): FLASH_ATTENTION_IS_AVAILABLE = False
try: try:
from flash_attn import flash_attn_func from flash_attn import flash_attn_func
except ModuleNotFoundError: FLASH_ATTENTION_IS_AVAILABLE = True
except ModuleNotFoundError:
if model_management.flash_attention_enabled():
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn") logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
exit(-1) exit(-1)
REGISTERED_ATTENTION_FUNCTIONS = {}
def register_attention_function(name: str, func: Callable):
# avoid replacing existing functions
if name not in REGISTERED_ATTENTION_FUNCTIONS:
REGISTERED_ATTENTION_FUNCTIONS[name] = func
else:
logging.warning(f"Attention function {name} already registered, skipping registration.")
def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]:
if name == "optimized":
return optimized_attention
elif name not in REGISTERED_ATTENTION_FUNCTIONS:
if default is ...:
raise KeyError(f"Attention function {name} not found.")
else:
return default
return REGISTERED_ATTENTION_FUNCTIONS[name]
from comfy.cli_args import args from comfy.cli_args import args
import comfy.ops import comfy.ops
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
@@ -91,7 +114,27 @@ class FeedForward(nn.Module):
def Normalize(in_channels, dtype=None, device=None): def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
def wrap_attn(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
remove_attn_wrapper_key = False
try:
if "_inside_attn_wrapper" not in kwargs:
transformer_options = kwargs.get("transformer_options", None)
remove_attn_wrapper_key = True
kwargs["_inside_attn_wrapper"] = True
if transformer_options is not None:
if "optimized_attention_override" in transformer_options:
return transformer_options["optimized_attention_override"](func, *args, **kwargs)
return func(*args, **kwargs)
finally:
if remove_attn_wrapper_key:
del kwargs["_inside_attn_wrapper"]
return wrapper
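wrap_attn is the hook that makes the threading above useful: every core attention implementation is decorated with it, and when the incoming kwargs carry a transformer_options dict containing "optimized_attention_override", that callable is invoked with the un-wrapped implementation plus the original arguments (the private _inside_attn_wrapper key prevents re-entry if the override calls another wrapped function). A minimal sketch of such an override, assuming q, k, v arrive positionally as at the call sites in this diff; my_attention_override and the debug logging are illustrative, not part of the commit:

import logging

def my_attention_override(func, q, k, v, *args, **kwargs):
    # func is the raw (un-wrapped) implementation that was about to run;
    # transformer_options rides along in kwargs for per-call decisions
    opts = kwargs.get("transformer_options", None) or {}
    logging.debug("attention override: q shape %s, option keys %s",
                  tuple(q.shape), list(opts))
    # fall through to the implementation that would have run anyway
    return func(q, k, v, *args, **kwargs)

# installed by whatever builds transformer_options, e.g.:
# transformer_options["optimized_attention_override"] = my_attention_override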
@wrap_attn
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
attn_precision = get_attn_precision(attn_precision, q.dtype) attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape: if skip_reshape:
@@ -159,8 +202,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
) )
return out return out
@wrap_attn
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
attn_precision = get_attn_precision(attn_precision, query.dtype) attn_precision = get_attn_precision(attn_precision, query.dtype)
if skip_reshape: if skip_reshape:
@@ -230,7 +273,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2) hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
return hidden_states return hidden_states
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): @wrap_attn
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
attn_precision = get_attn_precision(attn_precision, q.dtype) attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape: if skip_reshape:
@@ -359,7 +403,8 @@ try:
except: except:
pass pass
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): @wrap_attn
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
b = q.shape[0] b = q.shape[0]
dim_head = q.shape[-1] dim_head = q.shape[-1]
# check to make sure xformers isn't broken # check to make sure xformers isn't broken
@@ -374,7 +419,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
disabled_xformers = True disabled_xformers = True
if disabled_xformers: if disabled_xformers:
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape) return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
if skip_reshape: if skip_reshape:
# b h k d -> b k h d # b h k d -> b k h d
@@ -427,8 +472,8 @@ else:
#TODO: other GPUs ? #TODO: other GPUs ?
SDP_BATCH_LIMIT = 2**31 SDP_BATCH_LIMIT = 2**31
@wrap_attn
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape: if skip_reshape:
b, _, _, dim_head = q.shape b, _, _, dim_head = q.shape
else: else:
@@ -470,8 +515,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head) ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out return out
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape: if skip_reshape:
b, _, _, dim_head = q.shape b, _, _, dim_head = q.shape
tensor_layout = "HND" tensor_layout = "HND"
@@ -501,7 +546,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
lambda t: t.transpose(1, 2), lambda t: t.transpose(1, 2),
(q, k, v), (q, k, v),
) )
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape) return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)
if tensor_layout == "HND": if tensor_layout == "HND":
if not skip_output_reshape: if not skip_output_reshape:
@@ -534,8 +579,8 @@ except AttributeError as error:
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor: dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}" assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
@wrap_attn
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape: if skip_reshape:
b, _, _, dim_head = q.shape b, _, _, dim_head = q.shape
else: else:
@@ -597,6 +642,19 @@ else:
optimized_attention_masked = optimized_attention optimized_attention_masked = optimized_attention
# register core-supported attention functions
if SAGE_ATTENTION_IS_AVAILABLE:
register_attention_function("sage", attention_sage)
if FLASH_ATTENTION_IS_AVAILABLE:
register_attention_function("flash", attention_flash)
if model_management.xformers_enabled():
register_attention_function("xformers", attention_xformers)
register_attention_function("pytorch", attention_pytorch)
register_attention_function("sub_quad", attention_sub_quad)
register_attention_function("split", attention_split)
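The registry above lets extensions expose additional attention backends by name: "optimized" is reserved for whatever optimized_attention currently points at, and unknown names raise KeyError unless a default is supplied. A hedged usage sketch, written as it might appear in an external module; "my_backend" and my_attention are placeholder names, and the delegation to attention_pytorch only keeps the example runnable:

from comfy.ldm.modules.attention import (
    attention_pytorch, get_attention_function, register_attention_function, wrap_attn)

@wrap_attn   # decorated like the built-in backends so the override hook still applies
def my_attention(q, k, v, heads, mask=None, attn_precision=None,
                 skip_reshape=False, skip_output_reshape=False, **kwargs):
    # a real backend would run its own kernel; delegate to the SDPA path here
    return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape,
                             skip_output_reshape=skip_output_reshape)

register_attention_function("my_backend", my_attention)   # no-op if the name exists

get_attention_function("my_backend")             # -> my_attention
get_attention_function("optimized")              # -> current optimized_attention
get_attention_function("missing", default=None)  # -> None instead of KeyError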
def optimized_attention_for_device(device, mask=False, small_input=False): def optimized_attention_for_device(device, mask=False, small_input=False):
if small_input: if small_input:
if model_management.pytorch_attention_enabled(): if model_management.pytorch_attention_enabled():
@@ -629,7 +687,7 @@ class CrossAttention(nn.Module):
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
def forward(self, x, context=None, value=None, mask=None): def forward(self, x, context=None, value=None, mask=None, transformer_options={}):
q = self.to_q(x) q = self.to_q(x)
context = default(context, x) context = default(context, x)
k = self.to_k(context) k = self.to_k(context)
@@ -640,9 +698,9 @@ class CrossAttention(nn.Module):
v = self.to_v(context) v = self.to_v(context)
if mask is None: if mask is None:
out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision) out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
else: else:
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision) out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
return self.to_out(out) return self.to_out(out)
@@ -746,7 +804,7 @@ class BasicTransformerBlock(nn.Module):
n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options) n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
n = self.attn1.to_out(n) n = self.attn1.to_out(n)
else: else:
n = self.attn1(n, context=context_attn1, value=value_attn1) n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=transformer_options)
if "attn1_output_patch" in transformer_patches: if "attn1_output_patch" in transformer_patches:
patch = transformer_patches["attn1_output_patch"] patch = transformer_patches["attn1_output_patch"]
@@ -786,7 +844,7 @@ class BasicTransformerBlock(nn.Module):
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
n = self.attn2.to_out(n) n = self.attn2.to_out(n)
else: else:
n = self.attn2(n, context=context_attn2, value=value_attn2) n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=transformer_options)
if "attn2_output_patch" in transformer_patches: if "attn2_output_patch" in transformer_patches:
patch = transformer_patches["attn2_output_patch"] patch = transformer_patches["attn2_output_patch"]
@@ -1017,7 +1075,7 @@ class SpatialVideoTransformer(SpatialTransformer):
B, S, C = x_mix.shape B, S, C = x_mix.shape
x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps) x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options x_mix = mix_block(x_mix, context=time_context, transformer_options=transformer_options)
x_mix = rearrange( x_mix = rearrange(
x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
) )


@@ -606,7 +606,7 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
return _block_mixing(*args, **kwargs) return _block_mixing(*args, **kwargs)
def _block_mixing(context, x, context_block, x_block, c): def _block_mixing(context, x, context_block, x_block, c, transformer_options={}):
context_qkv, context_intermediates = context_block.pre_attention(context, c) context_qkv, context_intermediates = context_block.pre_attention(context, c)
if x_block.x_block_self_attn: if x_block.x_block_self_attn:
@@ -622,6 +622,7 @@ def _block_mixing(context, x, context_block, x_block, c):
attn = optimized_attention( attn = optimized_attention(
qkv[0], qkv[1], qkv[2], qkv[0], qkv[1], qkv[2],
heads=x_block.attn.num_heads, heads=x_block.attn.num_heads,
transformer_options=transformer_options,
) )
context_attn, x_attn = ( context_attn, x_attn = (
attn[:, : context_qkv[0].shape[1]], attn[:, : context_qkv[0].shape[1]],
@@ -637,6 +638,7 @@ def _block_mixing(context, x, context_block, x_block, c):
attn2 = optimized_attention( attn2 = optimized_attention(
x_qkv2[0], x_qkv2[1], x_qkv2[2], x_qkv2[0], x_qkv2[1], x_qkv2[2],
heads=x_block.attn2.num_heads, heads=x_block.attn2.num_heads,
transformer_options=transformer_options,
) )
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates) x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
else: else:
@@ -958,10 +960,10 @@ class MMDiT(nn.Module):
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"]) out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"], transformer_options=args["transformer_options"])
return out return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap}) out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod, "transformer_options": transformer_options}, {"original_block": block_wrap})
context = out["txt"] context = out["txt"]
x = out["img"] x = out["img"]
else: else:
@@ -970,6 +972,7 @@ class MMDiT(nn.Module):
x, x,
c=c_mod, c=c_mod,
use_checkpoint=self.use_checkpoint, use_checkpoint=self.use_checkpoint,
transformer_options=transformer_options,
) )
if control is not None: if control is not None:
control_o = control.get("output") control_o = control.get("output")


@@ -120,7 +120,7 @@ class Attention(nn.Module):
nn.Dropout(0.0) nn.Dropout(0.0)
) )
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape batch_size, sequence_length, _ = hidden_states.shape
query = self.to_q(hidden_states) query = self.to_q(hidden_states)
@@ -146,7 +146,7 @@ class Attention(nn.Module):
key = key.repeat_interleave(self.heads // self.kv_heads, dim=1) key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
value = value.repeat_interleave(self.heads // self.kv_heads, dim=1) value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True) hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)
hidden_states = self.to_out[0](hidden_states) hidden_states = self.to_out[0](hidden_states)
return hidden_states return hidden_states
@@ -182,16 +182,16 @@ class OmniGen2TransformerBlock(nn.Module):
self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
if self.modulation: if self.modulation:
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb) attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output) hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output) hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
else: else:
norm_hidden_states = self.norm1(hidden_states) norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb) attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
hidden_states = hidden_states + self.norm2(attn_output) hidden_states = hidden_states + self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states)) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
hidden_states = hidden_states + self.ffn_norm2(mlp_output) hidden_states = hidden_states + self.ffn_norm2(mlp_output)
@@ -390,7 +390,7 @@ class OmniGen2Transformer2DModel(nn.Module):
ref_img_sizes, img_sizes, ref_img_sizes, img_sizes,
) )
def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb): def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, transformer_options={}):
batch_size = len(hidden_states) batch_size = len(hidden_states)
hidden_states = self.x_embedder(hidden_states) hidden_states = self.x_embedder(hidden_states)
@@ -405,17 +405,17 @@ class OmniGen2Transformer2DModel(nn.Module):
shift += ref_img_len shift += ref_img_len
for layer in self.noise_refiner: for layer in self.noise_refiner:
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb) hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb, transformer_options=transformer_options)
if ref_image_hidden_states is not None: if ref_image_hidden_states is not None:
for layer in self.ref_image_refiner: for layer in self.ref_image_refiner:
ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb) ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb, transformer_options=transformer_options)
hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1) hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)
return hidden_states return hidden_states
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs): def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs):
B, C, H, W = x.shape B, C, H, W = x.shape
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
_, _, H_padded, W_padded = hidden_states.shape _, _, H_padded, W_padded = hidden_states.shape
@@ -444,7 +444,7 @@ class OmniGen2Transformer2DModel(nn.Module):
) )
for layer in self.context_refiner: for layer in self.context_refiner:
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb) text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options)
img_len = hidden_states.shape[1] img_len = hidden_states.shape[1]
combined_img_hidden_states = self.img_patch_embed_and_refine( combined_img_hidden_states = self.img_patch_embed_and_refine(
@@ -453,13 +453,14 @@ class OmniGen2Transformer2DModel(nn.Module):
noise_rotary_emb, ref_img_rotary_emb, noise_rotary_emb, ref_img_rotary_emb,
l_effective_ref_img_len, l_effective_img_len, l_effective_ref_img_len, l_effective_img_len,
temb, temb,
transformer_options=transformer_options,
) )
hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1) hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
attention_mask = None attention_mask = None
for layer in self.layers: for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb) hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb, transformer_options=transformer_options)
hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.norm_out(hidden_states, temb)


@@ -132,6 +132,7 @@ class Attention(nn.Module):
encoder_hidden_states_mask: torch.FloatTensor = None, encoder_hidden_states_mask: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
seq_txt = encoder_hidden_states.shape[1] seq_txt = encoder_hidden_states.shape[1]
@@ -159,7 +160,7 @@ class Attention(nn.Module):
joint_key = joint_key.flatten(start_dim=2) joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2) joint_value = joint_value.flatten(start_dim=2)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask) joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
txt_attn_output = joint_hidden_states[:, :seq_txt, :] txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :] img_attn_output = joint_hidden_states[:, seq_txt:, :]
@@ -226,6 +227,7 @@ class QwenImageTransformerBlock(nn.Module):
encoder_hidden_states_mask: torch.Tensor, encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor, temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_params = self.img_mod(temb) img_mod_params = self.img_mod(temb)
txt_mod_params = self.txt_mod(temb) txt_mod_params = self.txt_mod(temb)
@@ -242,6 +244,7 @@ class QwenImageTransformerBlock(nn.Module):
encoder_hidden_states=txt_modulated, encoder_hidden_states=txt_modulated,
encoder_hidden_states_mask=encoder_hidden_states_mask, encoder_hidden_states_mask=encoder_hidden_states_mask,
image_rotary_emb=image_rotary_emb, image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
) )
hidden_states = hidden_states + img_gate1 * img_attn_output hidden_states = hidden_states + img_gate1 * img_attn_output
@@ -434,9 +437,9 @@ class QwenImageTransformer2DModel(nn.Module):
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"]) out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
return out return out
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap}) out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
hidden_states = out["img"] hidden_states = out["img"]
encoder_hidden_states = out["txt"] encoder_hidden_states = out["txt"]
else: else:
@@ -446,11 +449,12 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states_mask=encoder_hidden_states_mask, encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb, temb=temb,
image_rotary_emb=image_rotary_emb, image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
) )
if "double_block" in patches: if "double_block" in patches:
for p in patches["double_block"]: for p in patches["double_block"]:
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i}) out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options})
hidden_states = out["img"] hidden_states = out["img"]
encoder_hidden_states = out["txt"] encoder_hidden_states = out["txt"]


@@ -52,7 +52,7 @@ class WanSelfAttention(nn.Module):
self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, freqs): def forward(self, x, freqs, transformer_options={}):
r""" r"""
Args: Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads] x(Tensor): Shape [B, L, num_heads, C / num_heads]
@@ -75,6 +75,7 @@ class WanSelfAttention(nn.Module):
k.view(b, s, n * d), k.view(b, s, n * d),
v, v,
heads=self.num_heads, heads=self.num_heads,
transformer_options=transformer_options,
) )
x = self.o(x) x = self.o(x)
@@ -83,7 +84,7 @@ class WanSelfAttention(nn.Module):
class WanT2VCrossAttention(WanSelfAttention): class WanT2VCrossAttention(WanSelfAttention):
def forward(self, x, context, **kwargs): def forward(self, x, context, transformer_options={}, **kwargs):
r""" r"""
Args: Args:
x(Tensor): Shape [B, L1, C] x(Tensor): Shape [B, L1, C]
@@ -95,7 +96,7 @@ class WanT2VCrossAttention(WanSelfAttention):
v = self.v(context) v = self.v(context)
# compute attention # compute attention
x = optimized_attention(q, k, v, heads=self.num_heads) x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
x = self.o(x) x = self.o(x)
return x return x
@@ -116,7 +117,7 @@ class WanI2VCrossAttention(WanSelfAttention):
# self.alpha = nn.Parameter(torch.zeros((1, ))) # self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, context, context_img_len): def forward(self, x, context, context_img_len, transformer_options={}):
r""" r"""
Args: Args:
x(Tensor): Shape [B, L1, C] x(Tensor): Shape [B, L1, C]
@@ -131,9 +132,9 @@ class WanI2VCrossAttention(WanSelfAttention):
v = self.v(context) v = self.v(context)
k_img = self.norm_k_img(self.k_img(context_img)) k_img = self.norm_k_img(self.k_img(context_img))
v_img = self.v_img(context_img) v_img = self.v_img(context_img)
img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads) img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options)
# compute attention # compute attention
x = optimized_attention(q, k, v, heads=self.num_heads) x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
# output # output
x = x + img_x x = x + img_x
@@ -206,6 +207,7 @@ class WanAttentionBlock(nn.Module):
freqs, freqs,
context, context,
context_img_len=257, context_img_len=257,
transformer_options={},
): ):
r""" r"""
Args: Args:
@@ -224,12 +226,12 @@ class WanAttentionBlock(nn.Module):
# self-attention # self-attention
y = self.self_attn( y = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs) freqs, transformer_options=transformer_options)
x = torch.addcmul(x, y, repeat_e(e[2], x)) x = torch.addcmul(x, y, repeat_e(e[2], x))
# cross-attention & ffn # cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len) x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = torch.addcmul(x, y, repeat_e(e[5], x)) x = torch.addcmul(x, y, repeat_e(e[5], x))
return x return x
@@ -559,12 +561,12 @@ class WanModel(torch.nn.Module):
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"] x = out["img"]
else: else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
# head # head
x = self.head(x, e) x = self.head(x, e)
@@ -742,17 +744,17 @@ class VaceWanModel(WanModel):
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"] x = out["img"]
else: else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
ii = self.vace_layers_mapping.get(i, None) ii = self.vace_layers_mapping.get(i, None)
if ii is not None: if ii is not None:
for iii in range(len(c)): for iii in range(len(c)):
c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
x += c_skip * vace_strength[iii] x += c_skip * vace_strength[iii]
del c_skip del c_skip
# head # head
@@ -841,12 +843,12 @@ class CameraWanModel(WanModel):
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"] x = out["img"]
else: else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
# head # head
x = self.head(x, e) x = self.head(x, e)
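Taken together, these changes mean a patch only has to populate one key to reroute attention for a whole model forward. A hedged end-to-end sketch; force_attention_backend and the idea of mutating the dict this way are assumptions, only get_attention_function and the "optimized_attention_override" key come from the diff:

from comfy.ldm.modules.attention import get_attention_function

def force_attention_backend(transformer_options, name):
    # Route every optimized_attention call made under these options to one
    # registered backend, if it is available.
    attn_func = get_attention_function(name, default=None)
    if attn_func is None:
        return transformer_options          # backend not registered; keep defaults
    def override(func, *args, **kwargs):
        # kwargs still carries _inside_attn_wrapper, so a wrap_attn-decorated
        # backend will not bounce back into this override
        return attn_func(*args, **kwargs)
    transformer_options["optimized_attention_override"] = override
    return transformer_options

In ComfyUI the dict itself typically comes from the model patcher's model_options["transformer_options"], which the models already receive at forward time and which this commit now threads down to every attention call.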