Made LTX work with optimized_attention_override

This commit is contained in:
Jedrzej Kosinski
2025-08-28 20:42:22 -07:00
parent 61b5c5fc75
commit 2cda45d1b4

View File

@@ -271,7 +271,7 @@ class CrossAttention(nn.Module):
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
def forward(self, x, context=None, mask=None, pe=None): def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
q = self.to_q(x) q = self.to_q(x)
context = x if context is None else context context = x if context is None else context
k = self.to_k(context) k = self.to_k(context)
@@ -285,9 +285,9 @@ class CrossAttention(nn.Module):
k = apply_rotary_emb(k, pe) k = apply_rotary_emb(k, pe)
if mask is None: if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision) out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
else: else:
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision) out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
return self.to_out(out) return self.to_out(out)
@@ -303,12 +303,12 @@ class BasicTransformerBlock(nn.Module):
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None): def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
x += self.attn2(x, context=context, mask=attention_mask) x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
x += self.ff(y) * gate_mlp x += self.ff(y) * gate_mlp
@@ -479,10 +479,10 @@ class LTXVModel(torch.nn.Module):
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"]) out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
return out return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap}) out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"] x = out["img"]
else: else:
x = block( x = block(
@@ -490,7 +490,8 @@ class LTXVModel(torch.nn.Module):
context=context, context=context,
attention_mask=attention_mask, attention_mask=attention_mask,
timestep=timestep, timestep=timestep,
pe=pe pe=pe,
transformer_options=transformer_options,
) )
# 3. Output # 3. Output