diff --git a/comfy/ldm/ace/attention.py b/comfy/ldm/ace/attention.py index f20a01669..670eb9783 100644 --- a/comfy/ldm/ace/attention.py +++ b/comfy/ldm/ace/attention.py @@ -133,6 +133,7 @@ class Attention(nn.Module): hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + transformer_options={}, **cross_attention_kwargs, ) -> torch.Tensor: return self.processor( @@ -140,6 +141,7 @@ class Attention(nn.Module): hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, + transformer_options=transformer_options, **cross_attention_kwargs, ) @@ -366,6 +368,7 @@ class CustomerAttnProcessor2_0: encoder_attention_mask: Optional[torch.FloatTensor] = None, rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None, rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None, + transformer_options={}, *args, **kwargs, ) -> torch.Tensor: @@ -433,7 +436,7 @@ class CustomerAttnProcessor2_0: # the output of sdp = (batch, num_heads, seq_len, head_dim) hidden_states = optimized_attention( - query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, + query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options, ).to(query.dtype) # linear proj @@ -697,6 +700,7 @@ class LinearTransformerBlock(nn.Module): rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None, rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None, temb: torch.FloatTensor = None, + transformer_options={}, ): N = hidden_states.shape[0] @@ -720,6 +724,7 @@ class LinearTransformerBlock(nn.Module): encoder_attention_mask=encoder_attention_mask, rotary_freqs_cis=rotary_freqs_cis, rotary_freqs_cis_cross=rotary_freqs_cis_cross, + transformer_options=transformer_options, ) else: attn_output, _ = self.attn( @@ -729,6 +734,7 @@ class LinearTransformerBlock(nn.Module): encoder_attention_mask=None, rotary_freqs_cis=rotary_freqs_cis, rotary_freqs_cis_cross=None, + transformer_options=transformer_options, ) if self.use_adaln_single: @@ -743,6 +749,7 @@ class LinearTransformerBlock(nn.Module): encoder_attention_mask=encoder_attention_mask, rotary_freqs_cis=rotary_freqs_cis, rotary_freqs_cis_cross=rotary_freqs_cis_cross, + transformer_options=transformer_options, ) hidden_states = attn_output + hidden_states diff --git a/comfy/ldm/ace/model.py b/comfy/ldm/ace/model.py index 41d85eeb5..399329853 100644 --- a/comfy/ldm/ace/model.py +++ b/comfy/ldm/ace/model.py @@ -314,6 +314,7 @@ class ACEStepTransformer2DModel(nn.Module): output_length: int = 0, block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, controlnet_scale: Union[float, torch.Tensor] = 1.0, + transformer_options={}, ): embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype)) temb = self.t_block(embedded_timestep) @@ -339,6 +340,7 @@ class ACEStepTransformer2DModel(nn.Module): rotary_freqs_cis=rotary_freqs_cis, rotary_freqs_cis_cross=encoder_rotary_freqs_cis, temb=temb, + transformer_options=transformer_options, ) output = self.final_layer(hidden_states, embedded_timestep, output_length) @@ -393,6 +395,7 @@ class ACEStepTransformer2DModel(nn.Module): output_length = hidden_states.shape[-1] + transformer_options = kwargs.get("transformer_options", {}) output = self.decode( hidden_states=hidden_states, attention_mask=attention_mask, @@ -402,6 +405,7 @@ class ACEStepTransformer2DModel(nn.Module): output_length=output_length, block_controlnet_hidden_states=block_controlnet_hidden_states, controlnet_scale=controlnet_scale, + transformer_options=transformer_options, ) return output diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py index 179c5b67e..5ca0e1e52 100644 --- a/comfy/ldm/audio/dit.py +++ b/comfy/ldm/audio/dit.py @@ -298,7 +298,8 @@ class Attention(nn.Module): mask = None, context_mask = None, rotary_pos_emb = None, - causal = None + causal = None, + transformer_options={}, ): h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None @@ -363,7 +364,7 @@ class Attention(nn.Module): heads_per_kv_head = h // kv_h k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) - out = optimized_attention(q, k, v, h, skip_reshape=True) + out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) out = self.to_out(out) if mask is not None: @@ -488,7 +489,8 @@ class TransformerBlock(nn.Module): global_cond=None, mask = None, context_mask = None, - rotary_pos_emb = None + rotary_pos_emb = None, + transformer_options={} ): if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: @@ -498,12 +500,12 @@ class TransformerBlock(nn.Module): residual = x x = self.pre_norm(x) x = x * (1 + scale_self) + shift_self - x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb) + x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options) x = x * torch.sigmoid(1 - gate_self) x = x + residual if context is not None: - x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options) if self.conformer is not None: x = x + self.conformer(x) @@ -517,10 +519,10 @@ class TransformerBlock(nn.Module): x = x + residual else: - x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb) + x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options) if context is not None: - x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options) if self.conformer is not None: x = x + self.conformer(x) @@ -606,7 +608,8 @@ class ContinuousTransformer(nn.Module): return_info = False, **kwargs ): - patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {}) + transformer_options = kwargs.get("transformer_options", {}) + patches_replace = transformer_options.get("patches_replace", {}) batch, seq, device = *x.shape[:2], x.device context = kwargs["context"] @@ -645,13 +648,13 @@ class ContinuousTransformer(nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"]) + out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context) + x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options) # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) if return_info: diff --git a/comfy/ldm/aura/mmdit.py b/comfy/ldm/aura/mmdit.py index d7f32b5e8..66d9613b6 100644 --- a/comfy/ldm/aura/mmdit.py +++ b/comfy/ldm/aura/mmdit.py @@ -85,7 +85,7 @@ class SingleAttention(nn.Module): ) #@torch.compile() - def forward(self, c): + def forward(self, c, transformer_options={}): bsz, seqlen1, _ = c.shape @@ -95,7 +95,7 @@ class SingleAttention(nn.Module): v = v.view(bsz, seqlen1, self.n_heads, self.head_dim) q, k = self.q_norm1(q), self.k_norm1(k) - output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True) + output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options) c = self.w1o(output) return c @@ -144,7 +144,7 @@ class DoubleAttention(nn.Module): #@torch.compile() - def forward(self, c, x): + def forward(self, c, x, transformer_options={}): bsz, seqlen1, _ = c.shape bsz, seqlen2, _ = x.shape @@ -168,7 +168,7 @@ class DoubleAttention(nn.Module): torch.cat([cv, xv], dim=1), ) - output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True) + output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options) c, x = output.split([seqlen1, seqlen2], dim=1) c = self.w1o(c) @@ -207,7 +207,7 @@ class MMDiTBlock(nn.Module): self.is_last = is_last #@torch.compile() - def forward(self, c, x, global_cond, **kwargs): + def forward(self, c, x, global_cond, transformer_options={}, **kwargs): cres, xres = c, x @@ -225,7 +225,7 @@ class MMDiTBlock(nn.Module): x = modulate(self.normX1(x), xshift_msa, xscale_msa) # attention - c, x = self.attn(c, x) + c, x = self.attn(c, x, transformer_options=transformer_options) c = self.normC2(cres + cgate_msa.unsqueeze(1) * c) @@ -255,13 +255,13 @@ class DiTBlock(nn.Module): self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations) #@torch.compile() - def forward(self, cx, global_cond, **kwargs): + def forward(self, cx, global_cond, transformer_options={}, **kwargs): cxres = cx shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX( global_cond ).chunk(6, dim=1) cx = modulate(self.norm1(cx), shift_msa, scale_msa) - cx = self.attn(cx) + cx = self.attn(cx, transformer_options=transformer_options) cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx) mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp)) cx = gate_mlp.unsqueeze(1) * mlpout @@ -473,13 +473,14 @@ class MMDiT(nn.Module): out = {} out["txt"], out["img"] = layer(args["txt"], args["img"], - args["vec"]) + args["vec"], + transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap}) c = out["txt"] x = out["img"] else: - c, x = layer(c, x, global_cond, **kwargs) + c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs) if len(self.single_layers) > 0: c_len = c.size(1) @@ -488,13 +489,13 @@ class MMDiT(nn.Module): if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = layer(args["img"], args["vec"]) + out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap}) + out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap}) cx = out["img"] else: - cx = layer(cx, global_cond, **kwargs) + cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs) x = cx[:, c_len:] diff --git a/comfy/ldm/cascade/common.py b/comfy/ldm/cascade/common.py index 3eaa0c821..42ef98c7a 100644 --- a/comfy/ldm/cascade/common.py +++ b/comfy/ldm/cascade/common.py @@ -32,12 +32,12 @@ class OptimizedAttention(nn.Module): self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device) - def forward(self, q, k, v): + def forward(self, q, k, v, transformer_options={}): q = self.to_q(q) k = self.to_k(k) v = self.to_v(v) - out = optimized_attention(q, k, v, self.heads) + out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options) return self.out_proj(out) @@ -47,13 +47,13 @@ class Attention2D(nn.Module): self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations) # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device) - def forward(self, x, kv, self_attn=False): + def forward(self, x, kv, self_attn=False, transformer_options={}): orig_shape = x.shape x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 if self_attn: kv = torch.cat([x, kv], dim=1) # x = self.attn(x, kv, kv, need_weights=False)[0] - x = self.attn(x, kv, kv) + x = self.attn(x, kv, kv, transformer_options=transformer_options) x = x.permute(0, 2, 1).view(*orig_shape) return x @@ -114,9 +114,9 @@ class AttnBlock(nn.Module): operations.Linear(c_cond, c, dtype=dtype, device=device) ) - def forward(self, x, kv): + def forward(self, x, kv, transformer_options={}): kv = self.kv_mapper(kv) - x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) + x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options) return x diff --git a/comfy/ldm/cascade/stage_b.py b/comfy/ldm/cascade/stage_b.py index 773830956..428c67fdf 100644 --- a/comfy/ldm/cascade/stage_b.py +++ b/comfy/ldm/cascade/stage_b.py @@ -173,7 +173,7 @@ class StageB(nn.Module): clip = self.clip_norm(clip) return clip - def _down_encode(self, x, r_embed, clip): + def _down_encode(self, x, r_embed, clip, transformer_options={}): level_outputs = [] block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) for down_block, downscaler, repmap in block_group: @@ -187,7 +187,7 @@ class StageB(nn.Module): elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)): - x = block(x, clip) + x = block(x, clip, transformer_options=transformer_options) elif isinstance(block, TimestepBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)): @@ -199,7 +199,7 @@ class StageB(nn.Module): level_outputs.insert(0, x) return level_outputs - def _up_decode(self, level_outputs, r_embed, clip): + def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}): x = level_outputs[0] block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) for i, (up_block, upscaler, repmap) in enumerate(block_group): @@ -216,7 +216,7 @@ class StageB(nn.Module): elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)): - x = block(x, clip) + x = block(x, clip, transformer_options=transformer_options) elif isinstance(block, TimestepBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)): @@ -228,7 +228,7 @@ class StageB(nn.Module): x = upscaler(x) return x - def forward(self, x, r, effnet, clip, pixels=None, **kwargs): + def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs): if pixels is None: pixels = x.new_zeros(x.size(0), 3, 8, 8) @@ -245,8 +245,8 @@ class StageB(nn.Module): nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True)) x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear', align_corners=True) - level_outputs = self._down_encode(x, r_embed, clip) - x = self._up_decode(level_outputs, r_embed, clip) + level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options) + x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options) return self.clf(x) def update_weights_ema(self, src_model, beta=0.999): diff --git a/comfy/ldm/cascade/stage_c.py b/comfy/ldm/cascade/stage_c.py index b952d0349..ebc4434e2 100644 --- a/comfy/ldm/cascade/stage_c.py +++ b/comfy/ldm/cascade/stage_c.py @@ -182,7 +182,7 @@ class StageC(nn.Module): clip = self.clip_norm(clip) return clip - def _down_encode(self, x, r_embed, clip, cnet=None): + def _down_encode(self, x, r_embed, clip, cnet=None, transformer_options={}): level_outputs = [] block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) for down_block, downscaler, repmap in block_group: @@ -201,7 +201,7 @@ class StageC(nn.Module): elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)): - x = block(x, clip) + x = block(x, clip, transformer_options=transformer_options) elif isinstance(block, TimestepBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)): @@ -213,7 +213,7 @@ class StageC(nn.Module): level_outputs.insert(0, x) return level_outputs - def _up_decode(self, level_outputs, r_embed, clip, cnet=None): + def _up_decode(self, level_outputs, r_embed, clip, cnet=None, transformer_options={}): x = level_outputs[0] block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) for i, (up_block, upscaler, repmap) in enumerate(block_group): @@ -235,7 +235,7 @@ class StageC(nn.Module): elif isinstance(block, AttnBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)): - x = block(x, clip) + x = block(x, clip, transformer_options=transformer_options) elif isinstance(block, TimestepBlock) or ( hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)): @@ -247,7 +247,7 @@ class StageC(nn.Module): x = upscaler(x) return x - def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs): + def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, transformer_options={}, **kwargs): # Process the conditioning embeddings r_embed = self.gen_r_embedding(r).to(dtype=x.dtype) for c in self.t_conds: @@ -262,8 +262,8 @@ class StageC(nn.Module): # Model Blocks x = self.embedding(x) - level_outputs = self._down_encode(x, r_embed, clip, cnet) - x = self._up_decode(level_outputs, r_embed, clip, cnet) + level_outputs = self._down_encode(x, r_embed, clip, cnet, transformer_options=transformer_options) + x = self._up_decode(level_outputs, r_embed, clip, cnet, transformer_options=transformer_options) return self.clf(x) def update_weights_ema(self, src_model, beta=0.999): diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py index 2a0dec606..fc7110cce 100644 --- a/comfy/ldm/chroma/layers.py +++ b/comfy/ldm/chroma/layers.py @@ -76,7 +76,7 @@ class DoubleStreamBlock(nn.Module): ) self.flipped_img_txt = flipped_img_txt - def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None): + def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}): (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec # prepare image for attention @@ -95,7 +95,7 @@ class DoubleStreamBlock(nn.Module): attn = attention(torch.cat((txt_q, img_q), dim=2), torch.cat((txt_k, img_k), dim=2), torch.cat((txt_v, img_v), dim=2), - pe=pe, mask=attn_mask) + pe=pe, mask=attn_mask, transformer_options=transformer_options) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] @@ -148,7 +148,7 @@ class SingleStreamBlock(nn.Module): self.mlp_act = nn.GELU(approximate="tanh") - def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor: + def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor: mod = vec x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x)) qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) @@ -157,7 +157,7 @@ class SingleStreamBlock(nn.Module): q, k = self.norm(q, k, v) # compute attention - attn = attention(q, k, v, pe=pe, mask=attn_mask) + attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) x.addcmul_(mod.gate, output) diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index 5cff44dc8..4f709f87d 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -193,14 +193,16 @@ class Chroma(nn.Module): txt=args["txt"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) return out out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": double_mod, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] @@ -209,7 +211,8 @@ class Chroma(nn.Module): txt=txt, vec=double_mod, pe=pe, - attn_mask=attn_mask) + attn_mask=attn_mask, + transformer_options=transformer_options) if control is not None: # Controlnet control_i = control.get("input") @@ -229,17 +232,19 @@ class Chroma(nn.Module): out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) return out out = blocks_replace[("single_block", i)]({"img": img, "vec": single_mod, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask) + img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) if control is not None: # Controlnet control_o = control.get("output") diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py index 5c4356a3f..afb43d469 100644 --- a/comfy/ldm/cosmos/blocks.py +++ b/comfy/ldm/cosmos/blocks.py @@ -176,6 +176,7 @@ class Attention(nn.Module): context=None, mask=None, rope_emb=None, + transformer_options={}, **kwargs, ): """ @@ -184,7 +185,7 @@ class Attention(nn.Module): context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None """ q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs) - out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True) + out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True, transformer_options=transformer_options) del q, k, v out = rearrange(out, " b n s c -> s b (n c)") return self.to_out(out) @@ -546,6 +547,7 @@ class VideoAttn(nn.Module): context: Optional[torch.Tensor] = None, crossattn_mask: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: """ Forward pass for video attention. @@ -571,6 +573,7 @@ class VideoAttn(nn.Module): context_M_B_D, crossattn_mask, rope_emb=rope_emb_L_1_1_D, + transformer_options=transformer_options, ) x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W) return x_T_H_W_B_D @@ -665,6 +668,7 @@ class DITBuildingBlock(nn.Module): crossattn_mask: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_3D: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: """ Forward pass for dynamically configured blocks with adaptive normalization. @@ -702,6 +706,7 @@ class DITBuildingBlock(nn.Module): adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D), context=None, rope_emb_L_1_1_D=rope_emb_L_1_1_D, + transformer_options=transformer_options, ) elif self.block_type in ["cross_attn", "ca"]: x = x + gate_1_1_1_B_D * self.block( @@ -709,6 +714,7 @@ class DITBuildingBlock(nn.Module): context=crossattn_emb, crossattn_mask=crossattn_mask, rope_emb_L_1_1_D=rope_emb_L_1_1_D, + transformer_options=transformer_options, ) else: raise ValueError(f"Unknown block type: {self.block_type}") @@ -784,6 +790,7 @@ class GeneralDITTransformerBlock(nn.Module): crossattn_mask: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_3D: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: for block in self.blocks: x = block( @@ -793,5 +800,6 @@ class GeneralDITTransformerBlock(nn.Module): crossattn_mask, rope_emb_L_1_1_D=rope_emb_L_1_1_D, adaln_lora_B_3D=adaln_lora_B_3D, + transformer_options=transformer_options, ) return x diff --git a/comfy/ldm/cosmos/model.py b/comfy/ldm/cosmos/model.py index 53698b758..52ef7ef43 100644 --- a/comfy/ldm/cosmos/model.py +++ b/comfy/ldm/cosmos/model.py @@ -520,6 +520,7 @@ class GeneralDIT(nn.Module): x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}" + transformer_options = kwargs.get("transformer_options", {}) for _, block in self.blocks.items(): assert ( self.blocks["block0"].x_format == block.x_format @@ -534,6 +535,7 @@ class GeneralDIT(nn.Module): crossattn_mask, rope_emb_L_1_1_D=rope_emb_L_1_1_D, adaln_lora_B_3D=adaln_lora_B_3D, + transformer_options=transformer_options, ) x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D") diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index fcc83ba76..07a4fc79f 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -44,7 +44,7 @@ class GPT2FeedForward(nn.Module): return x -def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor: +def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor: """Computes multi-head attention using PyTorch's native implementation. This function provides a PyTorch backend alternative to Transformer Engine's attention operation. @@ -71,7 +71,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1]) k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) - return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True) + return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, transformer_options=transformer_options) class Attention(nn.Module): @@ -180,8 +180,8 @@ class Attention(nn.Module): return q, k, v - def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: - result = self.attn_op(q, k, v) # [B, S, H, D] + def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor: + result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D] return self.output_dropout(self.output_proj(result)) def forward( @@ -189,6 +189,7 @@ class Attention(nn.Module): x: torch.Tensor, context: Optional[torch.Tensor] = None, rope_emb: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: """ Args: @@ -196,7 +197,7 @@ class Attention(nn.Module): context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None """ q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb) - return self.compute_attention(q, k, v) + return self.compute_attention(q, k, v, transformer_options=transformer_options) class Timesteps(nn.Module): @@ -459,6 +460,7 @@ class Block(nn.Module): rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_T_3D: Optional[torch.Tensor] = None, extra_per_block_pos_emb: Optional[torch.Tensor] = None, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: if extra_per_block_pos_emb is not None: x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb @@ -512,6 +514,7 @@ class Block(nn.Module): rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), None, rope_emb=rope_emb_L_1_1_D, + transformer_options=transformer_options, ), "b (t h w) d -> b t h w d", t=T, @@ -525,6 +528,7 @@ class Block(nn.Module): layer_norm_cross_attn: Callable, _scale_cross_attn_B_T_1_1_D: torch.Tensor, _shift_cross_attn_B_T_1_1_D: torch.Tensor, + transformer_options: Optional[dict] = {}, ) -> torch.Tensor: _normalized_x_B_T_H_W_D = _fn( _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D @@ -534,6 +538,7 @@ class Block(nn.Module): rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), crossattn_emb, rope_emb=rope_emb_L_1_1_D, + transformer_options=transformer_options, ), "b (t h w) d -> b t h w d", t=T, @@ -547,6 +552,7 @@ class Block(nn.Module): self.layer_norm_cross_attn, scale_cross_attn_B_T_1_1_D, shift_cross_attn_B_T_1_1_D, + transformer_options=transformer_options, ) x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D @@ -865,6 +871,7 @@ class MiniTrainDIT(nn.Module): "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0), "adaln_lora_B_T_3D": adaln_lora_B_T_3D, "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + "transformer_options": kwargs.get("transformer_options", {}), } for block in self.blocks: x_B_T_H_W_D = block( diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 113eb2096..ef21b416b 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -159,7 +159,7 @@ class DoubleStreamBlock(nn.Module): ) self.flipped_img_txt = flipped_img_txt - def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None): + def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}): img_mod1, img_mod2 = self.img_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec) @@ -182,7 +182,7 @@ class DoubleStreamBlock(nn.Module): attn = attention(torch.cat((img_q, txt_q), dim=2), torch.cat((img_k, txt_k), dim=2), torch.cat((img_v, txt_v), dim=2), - pe=pe, mask=attn_mask) + pe=pe, mask=attn_mask, transformer_options=transformer_options) img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:] else: @@ -190,7 +190,7 @@ class DoubleStreamBlock(nn.Module): attn = attention(torch.cat((txt_q, img_q), dim=2), torch.cat((txt_k, img_k), dim=2), torch.cat((txt_v, img_v), dim=2), - pe=pe, mask=attn_mask) + pe=pe, mask=attn_mask, transformer_options=transformer_options) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] @@ -244,7 +244,7 @@ class SingleStreamBlock(nn.Module): self.mlp_act = nn.GELU(approximate="tanh") self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) - def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor: + def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor: mod, _ = self.modulation(vec) qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) @@ -252,7 +252,7 @@ class SingleStreamBlock(nn.Module): q, k = self.norm(q, k, v) # compute attention - attn = attention(q, k, v, pe=pe, mask=attn_mask) + attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) x += apply_mod(output, mod.gate, None, modulation_dims) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 3e0978176..4d743cda2 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -6,7 +6,7 @@ from comfy.ldm.modules.attention import optimized_attention import comfy.model_management -def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor: +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: q_shape = q.shape k_shape = k.shape @@ -17,7 +17,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor: k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v) heads = q.shape[1] - x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask) + x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 1344c3a57..2f03b0552 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -135,14 +135,16 @@ class Flux(nn.Module): txt=args["txt"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) return out out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] @@ -151,7 +153,8 @@ class Flux(nn.Module): txt=txt, vec=vec, pe=pe, - attn_mask=attn_mask) + attn_mask=attn_mask, + transformer_options=transformer_options) if control is not None: # Controlnet control_i = control.get("input") @@ -172,17 +175,19 @@ class Flux(nn.Module): out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options")) return out out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) if control is not None: # Controlnet control_o = control.get("output") diff --git a/comfy/ldm/genmo/joint_model/asymm_models_joint.py b/comfy/ldm/genmo/joint_model/asymm_models_joint.py index 366a8b713..5c1bb4d42 100644 --- a/comfy/ldm/genmo/joint_model/asymm_models_joint.py +++ b/comfy/ldm/genmo/joint_model/asymm_models_joint.py @@ -109,6 +109,7 @@ class AsymmetricAttention(nn.Module): scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm. scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm. crop_y, + transformer_options={}, **rope_rotation, ) -> Tuple[torch.Tensor, torch.Tensor]: rope_cos = rope_rotation.get("rope_cos") @@ -143,7 +144,7 @@ class AsymmetricAttention(nn.Module): xy = optimized_attention(q, k, - v, self.num_heads, skip_reshape=True) + v, self.num_heads, skip_reshape=True, transformer_options=transformer_options) x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1) x = self.proj_x(x) @@ -224,6 +225,7 @@ class AsymmetricJointBlock(nn.Module): x: torch.Tensor, c: torch.Tensor, y: torch.Tensor, + transformer_options={}, **attn_kwargs, ): """Forward pass of a block. @@ -256,6 +258,7 @@ class AsymmetricJointBlock(nn.Module): y, scale_x=scale_msa_x, scale_y=scale_msa_y, + transformer_options=transformer_options, **attn_kwargs, ) @@ -524,10 +527,11 @@ class AsymmDiTJoint(nn.Module): args["txt"], rope_cos=args["rope_cos"], rope_sin=args["rope_sin"], - crop_y=args["num_tokens"] + crop_y=args["num_tokens"], + transformer_options=args["transformer_options"] ) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens, "transformer_options": transformer_options}, {"original_block": block_wrap}) y_feat = out["txt"] x = out["img"] else: @@ -538,6 +542,7 @@ class AsymmDiTJoint(nn.Module): rope_cos=rope_cos, rope_sin=rope_sin, crop_y=num_tokens, + transformer_options=transformer_options, ) # (B, M, D), (B, L, D) del y_feat # Final layers don't use dense text features. diff --git a/comfy/ldm/hidream/model.py b/comfy/ldm/hidream/model.py index ae49cf945..28d81c79e 100644 --- a/comfy/ldm/hidream/model.py +++ b/comfy/ldm/hidream/model.py @@ -72,8 +72,8 @@ class TimestepEmbed(nn.Module): return t_emb -def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor): - return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2]) +def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, transformer_options={}): + return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2], transformer_options=transformer_options) class HiDreamAttnProcessor_flashattn: @@ -86,6 +86,7 @@ class HiDreamAttnProcessor_flashattn: image_tokens_masks: Optional[torch.FloatTensor] = None, text_tokens: Optional[torch.FloatTensor] = None, rope: torch.FloatTensor = None, + transformer_options={}, *args, **kwargs, ) -> torch.FloatTensor: @@ -133,7 +134,7 @@ class HiDreamAttnProcessor_flashattn: query = torch.cat([query_1, query_2], dim=-1) key = torch.cat([key_1, key_2], dim=-1) - hidden_states = attention(query, key, value) + hidden_states = attention(query, key, value, transformer_options=transformer_options) if not attn.single: hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1) @@ -199,6 +200,7 @@ class HiDreamAttention(nn.Module): image_tokens_masks: torch.FloatTensor = None, norm_text_tokens: torch.FloatTensor = None, rope: torch.FloatTensor = None, + transformer_options={}, ) -> torch.Tensor: return self.processor( self, @@ -206,6 +208,7 @@ class HiDreamAttention(nn.Module): image_tokens_masks = image_tokens_masks, text_tokens = norm_text_tokens, rope = rope, + transformer_options=transformer_options, ) @@ -406,7 +409,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module): text_tokens: Optional[torch.FloatTensor] = None, adaln_input: Optional[torch.FloatTensor] = None, rope: torch.FloatTensor = None, - + transformer_options={}, ) -> torch.FloatTensor: wtype = image_tokens.dtype shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \ @@ -419,6 +422,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module): norm_image_tokens, image_tokens_masks, rope = rope, + transformer_options=transformer_options, ) image_tokens = gate_msa_i * attn_output_i + image_tokens @@ -483,6 +487,7 @@ class HiDreamImageTransformerBlock(nn.Module): text_tokens: Optional[torch.FloatTensor] = None, adaln_input: Optional[torch.FloatTensor] = None, rope: torch.FloatTensor = None, + transformer_options={}, ) -> torch.FloatTensor: wtype = image_tokens.dtype shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \ @@ -500,6 +505,7 @@ class HiDreamImageTransformerBlock(nn.Module): image_tokens_masks, norm_text_tokens, rope = rope, + transformer_options=transformer_options, ) image_tokens = gate_msa_i * attn_output_i + image_tokens @@ -550,6 +556,7 @@ class HiDreamImageBlock(nn.Module): text_tokens: Optional[torch.FloatTensor] = None, adaln_input: torch.FloatTensor = None, rope: torch.FloatTensor = None, + transformer_options={}, ) -> torch.FloatTensor: return self.block( image_tokens, @@ -557,6 +564,7 @@ class HiDreamImageBlock(nn.Module): text_tokens, adaln_input, rope, + transformer_options=transformer_options, ) @@ -786,6 +794,7 @@ class HiDreamImageTransformer2DModel(nn.Module): text_tokens = cur_encoder_hidden_states, adaln_input = adaln_input, rope = rope, + transformer_options=transformer_options, ) initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len] block_id += 1 @@ -809,6 +818,7 @@ class HiDreamImageTransformer2DModel(nn.Module): text_tokens=None, adaln_input=adaln_input, rope=rope, + transformer_options=transformer_options, ) hidden_states = hidden_states[:, :hidden_states_seq_len] block_id += 1 diff --git a/comfy/ldm/hunyuan3d/model.py b/comfy/ldm/hunyuan3d/model.py index 0fa5e78c1..4991b1645 100644 --- a/comfy/ldm/hunyuan3d/model.py +++ b/comfy/ldm/hunyuan3d/model.py @@ -99,14 +99,16 @@ class Hunyuan3Dv2(nn.Module): txt=args["txt"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args["transformer_options"]) return out out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] @@ -115,7 +117,8 @@ class Hunyuan3Dv2(nn.Module): txt=txt, vec=vec, pe=pe, - attn_mask=attn_mask) + attn_mask=attn_mask, + transformer_options=transformer_options) img = torch.cat((txt, img), 1) @@ -126,17 +129,19 @@ class Hunyuan3Dv2(nn.Module): out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], - attn_mask=args.get("attn_mask")) + attn_mask=args.get("attn_mask"), + transformer_options=args["transformer_options"]) return out out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, - "attn_mask": attn_mask}, + "attn_mask": attn_mask, + "transformer_options": transformer_options}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) img = img[:, txt.shape[1]:, ...] img = self.final_layer(img, vec) diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index da1011596..17089c701 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -78,13 +78,13 @@ class TokenRefinerBlock(nn.Module): operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), ) - def forward(self, x, c, mask): + def forward(self, x, c, mask, transformer_options={}): mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1) norm_x = self.norm1(x) qkv = self.self_attn.qkv(norm_x) q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4) - attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True) + attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True, transformer_options=transformer_options) x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1) x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1) @@ -115,14 +115,14 @@ class IndividualTokenRefiner(nn.Module): ] ) - def forward(self, x, c, mask): + def forward(self, x, c, mask, transformer_options={}): m = None if mask is not None: m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1) m = m + m.transpose(2, 3) for block in self.blocks: - x = block(x, c, m) + x = block(x, c, m, transformer_options=transformer_options) return x @@ -150,6 +150,7 @@ class TokenRefiner(nn.Module): x, timesteps, mask, + transformer_options={}, ): t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype)) # m = mask.float().unsqueeze(-1) @@ -158,7 +159,7 @@ class TokenRefiner(nn.Module): c = t + self.c_embedder(c.to(x.dtype)) x = self.input_embedder(x) - x = self.individual_token_refiner(x, c, mask) + x = self.individual_token_refiner(x, c, mask, transformer_options=transformer_options) return x class HunyuanVideo(nn.Module): @@ -267,7 +268,7 @@ class HunyuanVideo(nn.Module): if txt_mask is not None and not torch.is_floating_point(txt_mask): txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max - txt = self.txt_in(txt, timesteps, txt_mask) + txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options) ids = torch.cat((img_ids, txt_ids), dim=1) pe = self.pe_embedder(ids) @@ -285,14 +286,14 @@ class HunyuanVideo(nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"]) + out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt, 'transformer_options': transformer_options}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] else: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt) + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt, transformer_options=transformer_options) if control is not None: # Controlnet control_i = control.get("input") @@ -307,13 +308,13 @@ class HunyuanVideo(nn.Module): if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"]) + out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap}) + out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims, 'transformer_options': transformer_options}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims, transformer_options=transformer_options) if control is not None: # Controlnet control_o = control.get("output") diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index aa2ea62b1..def365ba7 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -271,7 +271,7 @@ class CrossAttention(nn.Module): self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) - def forward(self, x, context=None, mask=None, pe=None): + def forward(self, x, context=None, mask=None, pe=None, transformer_options={}): q = self.to_q(x) context = x if context is None else context k = self.to_k(context) @@ -285,9 +285,9 @@ class CrossAttention(nn.Module): k = apply_rotary_emb(k, pe) if mask is None: - out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision) + out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) else: - out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision) + out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) return self.to_out(out) @@ -303,12 +303,12 @@ class BasicTransformerBlock(nn.Module): self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) - def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None): + def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) - x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa + x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa - x += self.attn2(x, context=context, mask=attention_mask) + x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp x += self.ff(y) * gate_mlp @@ -479,10 +479,10 @@ class LTXVModel(torch.nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"]) + out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: x = block( @@ -490,7 +490,8 @@ class LTXVModel(torch.nn.Module): context=context, attention_mask=attention_mask, timestep=timestep, - pe=pe + pe=pe, + transformer_options=transformer_options, ) # 3. Output diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index e08ed817d..f87d98ac0 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -104,6 +104,7 @@ class JointAttention(nn.Module): x: torch.Tensor, x_mask: torch.Tensor, freqs_cis: torch.Tensor, + transformer_options={}, ) -> torch.Tensor: """ @@ -140,7 +141,7 @@ class JointAttention(nn.Module): if n_rep >= 1: xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) - output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True) + output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options) return self.out(output) @@ -268,6 +269,7 @@ class JointTransformerBlock(nn.Module): x_mask: torch.Tensor, freqs_cis: torch.Tensor, adaln_input: Optional[torch.Tensor]=None, + transformer_options={}, ): """ Perform a forward pass through the TransformerBlock. @@ -290,6 +292,7 @@ class JointTransformerBlock(nn.Module): modulate(self.attention_norm1(x), scale_msa), x_mask, freqs_cis, + transformer_options=transformer_options, ) ) x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( @@ -304,6 +307,7 @@ class JointTransformerBlock(nn.Module): self.attention_norm1(x), x_mask, freqs_cis, + transformer_options=transformer_options, ) ) x = x + self.ffn_norm2( @@ -494,7 +498,7 @@ class NextDiT(nn.Module): return imgs def patchify_and_embed( - self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens + self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={} ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]: bsz = len(x) pH = pW = self.patch_size @@ -554,7 +558,7 @@ class NextDiT(nn.Module): # refine context for layer in self.context_refiner: - cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) + cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options) # refine image flat_x = [] @@ -573,7 +577,7 @@ class NextDiT(nn.Module): padded_img_embed = self.x_embedder(padded_img_embed) padded_img_mask = padded_img_mask.unsqueeze(1) for layer in self.noise_refiner: - padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t) + padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options) if cap_mask is not None: mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device) @@ -616,12 +620,13 @@ class NextDiT(nn.Module): cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute + transformer_options = kwargs.get("transformer_options", {}) x_is_tensor = isinstance(x, torch.Tensor) - x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens) + x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) freqs_cis = freqs_cis.to(x.device) for layer in self.layers: - x = layer(x, mask, freqs_cis, adaln_input) + x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options) x = self.final_layer(x, adaln_input) x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w] diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 043df28df..bf2553c37 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -5,8 +5,9 @@ import torch import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat -from typing import Optional +from typing import Optional, Any, Callable, Union import logging +import functools from .diffusionmodules.util import AlphaBlender, timestep_embedding from .sub_quadratic_attention import efficient_dot_product_attention @@ -17,23 +18,45 @@ if model_management.xformers_enabled(): import xformers import xformers.ops -if model_management.sage_attention_enabled(): - try: - from sageattention import sageattn - except ModuleNotFoundError as e: +SAGE_ATTENTION_IS_AVAILABLE = False +try: + from sageattention import sageattn + SAGE_ATTENTION_IS_AVAILABLE = True +except ModuleNotFoundError as e: + if model_management.sage_attention_enabled(): if e.name == "sageattention": logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention") else: raise e exit(-1) -if model_management.flash_attention_enabled(): - try: - from flash_attn import flash_attn_func - except ModuleNotFoundError: +FLASH_ATTENTION_IS_AVAILABLE = False +try: + from flash_attn import flash_attn_func + FLASH_ATTENTION_IS_AVAILABLE = True +except ModuleNotFoundError: + if model_management.flash_attention_enabled(): logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn") exit(-1) +REGISTERED_ATTENTION_FUNCTIONS = {} +def register_attention_function(name: str, func: Callable): + # avoid replacing existing functions + if name not in REGISTERED_ATTENTION_FUNCTIONS: + REGISTERED_ATTENTION_FUNCTIONS[name] = func + else: + logging.warning(f"Attention function {name} already registered, skipping registration.") + +def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]: + if name == "optimized": + return optimized_attention + elif name not in REGISTERED_ATTENTION_FUNCTIONS: + if default is ...: + raise KeyError(f"Attention function {name} not found.") + else: + return default + return REGISTERED_ATTENTION_FUNCTIONS[name] + from comfy.cli_args import args import comfy.ops ops = comfy.ops.disable_weight_init @@ -91,7 +114,27 @@ class FeedForward(nn.Module): def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) -def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): + +def wrap_attn(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + remove_attn_wrapper_key = False + try: + if "_inside_attn_wrapper" not in kwargs: + transformer_options = kwargs.get("transformer_options", None) + remove_attn_wrapper_key = True + kwargs["_inside_attn_wrapper"] = True + if transformer_options is not None: + if "optimized_attention_override" in transformer_options: + return transformer_options["optimized_attention_override"](func, *args, **kwargs) + return func(*args, **kwargs) + finally: + if remove_attn_wrapper_key: + del kwargs["_inside_attn_wrapper"] + return wrapper + +@wrap_attn +def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): attn_precision = get_attn_precision(attn_precision, q.dtype) if skip_reshape: @@ -159,8 +202,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape ) return out - -def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): attn_precision = get_attn_precision(attn_precision, query.dtype) if skip_reshape: @@ -230,7 +273,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2) return hidden_states -def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): attn_precision = get_attn_precision(attn_precision, q.dtype) if skip_reshape: @@ -359,7 +403,8 @@ try: except: pass -def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): b = q.shape[0] dim_head = q.shape[-1] # check to make sure xformers isn't broken @@ -374,7 +419,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh disabled_xformers = True if disabled_xformers: - return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape) + return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs) if skip_reshape: # b h k d -> b k h d @@ -427,8 +472,8 @@ else: #TODO: other GPUs ? SDP_BATCH_LIMIT = 2**31 - -def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): if skip_reshape: b, _, _, dim_head = q.shape else: @@ -470,8 +515,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head) return out - -def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): if skip_reshape: b, _, _, dim_head = q.shape tensor_layout = "HND" @@ -501,7 +546,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= lambda t: t.transpose(1, 2), (q, k, v), ) - return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape) + return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs) if tensor_layout == "HND": if not skip_output_reshape: @@ -534,8 +579,8 @@ except AttributeError as error: dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor: assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}" - -def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +@wrap_attn +def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): if skip_reshape: b, _, _, dim_head = q.shape else: @@ -597,6 +642,19 @@ else: optimized_attention_masked = optimized_attention + +# register core-supported attention functions +if SAGE_ATTENTION_IS_AVAILABLE: + register_attention_function("sage", attention_sage) +if FLASH_ATTENTION_IS_AVAILABLE: + register_attention_function("flash", attention_flash) +if model_management.xformers_enabled(): + register_attention_function("xformers", attention_xformers) +register_attention_function("pytorch", attention_pytorch) +register_attention_function("sub_quad", attention_sub_quad) +register_attention_function("split", attention_split) + + def optimized_attention_for_device(device, mask=False, small_input=False): if small_input: if model_management.pytorch_attention_enabled(): @@ -629,7 +687,7 @@ class CrossAttention(nn.Module): self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) - def forward(self, x, context=None, value=None, mask=None): + def forward(self, x, context=None, value=None, mask=None, transformer_options={}): q = self.to_q(x) context = default(context, x) k = self.to_k(context) @@ -640,9 +698,9 @@ class CrossAttention(nn.Module): v = self.to_v(context) if mask is None: - out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision) + out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) else: - out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision) + out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) return self.to_out(out) @@ -746,7 +804,7 @@ class BasicTransformerBlock(nn.Module): n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options) n = self.attn1.to_out(n) else: - n = self.attn1(n, context=context_attn1, value=value_attn1) + n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=transformer_options) if "attn1_output_patch" in transformer_patches: patch = transformer_patches["attn1_output_patch"] @@ -786,7 +844,7 @@ class BasicTransformerBlock(nn.Module): n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) n = self.attn2.to_out(n) else: - n = self.attn2(n, context=context_attn2, value=value_attn2) + n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=transformer_options) if "attn2_output_patch" in transformer_patches: patch = transformer_patches["attn2_output_patch"] @@ -1017,7 +1075,7 @@ class SpatialVideoTransformer(SpatialTransformer): B, S, C = x_mix.shape x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps) - x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options + x_mix = mix_block(x_mix, context=time_context, transformer_options=transformer_options) x_mix = rearrange( x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps ) diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 4d6beba2d..42f406f1a 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -606,7 +606,7 @@ def block_mixing(*args, use_checkpoint=True, **kwargs): return _block_mixing(*args, **kwargs) -def _block_mixing(context, x, context_block, x_block, c): +def _block_mixing(context, x, context_block, x_block, c, transformer_options={}): context_qkv, context_intermediates = context_block.pre_attention(context, c) if x_block.x_block_self_attn: @@ -622,6 +622,7 @@ def _block_mixing(context, x, context_block, x_block, c): attn = optimized_attention( qkv[0], qkv[1], qkv[2], heads=x_block.attn.num_heads, + transformer_options=transformer_options, ) context_attn, x_attn = ( attn[:, : context_qkv[0].shape[1]], @@ -637,6 +638,7 @@ def _block_mixing(context, x, context_block, x_block, c): attn2 = optimized_attention( x_qkv2[0], x_qkv2[1], x_qkv2[2], heads=x_block.attn2.num_heads, + transformer_options=transformer_options, ) x = x_block.post_attention_x(x_attn, attn2, *x_intermediates) else: @@ -958,10 +960,10 @@ class MMDiT(nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"]) + out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod, "transformer_options": transformer_options}, {"original_block": block_wrap}) context = out["txt"] x = out["img"] else: @@ -970,6 +972,7 @@ class MMDiT(nn.Module): x, c=c_mod, use_checkpoint=self.use_checkpoint, + transformer_options=transformer_options, ) if control is not None: control_o = control.get("output") diff --git a/comfy/ldm/omnigen/omnigen2.py b/comfy/ldm/omnigen/omnigen2.py index 4884449f8..82edc92da 100644 --- a/comfy/ldm/omnigen/omnigen2.py +++ b/comfy/ldm/omnigen/omnigen2.py @@ -120,7 +120,7 @@ class Attention(nn.Module): nn.Dropout(0.0) ) - def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor: batch_size, sequence_length, _ = hidden_states.shape query = self.to_q(hidden_states) @@ -146,7 +146,7 @@ class Attention(nn.Module): key = key.repeat_interleave(self.heads // self.kv_heads, dim=1) value = value.repeat_interleave(self.heads // self.kv_heads, dim=1) - hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True) + hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options) hidden_states = self.to_out[0](hidden_states) return hidden_states @@ -182,16 +182,16 @@ class OmniGen2TransformerBlock(nn.Module): self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) - def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor: if self.modulation: norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) - attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb) + attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options) hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))) hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output) else: norm_hidden_states = self.norm1(hidden_states) - attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb) + attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options) hidden_states = hidden_states + self.norm2(attn_output) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states)) hidden_states = hidden_states + self.ffn_norm2(mlp_output) @@ -390,7 +390,7 @@ class OmniGen2Transformer2DModel(nn.Module): ref_img_sizes, img_sizes, ) - def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb): + def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, transformer_options={}): batch_size = len(hidden_states) hidden_states = self.x_embedder(hidden_states) @@ -405,17 +405,17 @@ class OmniGen2Transformer2DModel(nn.Module): shift += ref_img_len for layer in self.noise_refiner: - hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb) + hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb, transformer_options=transformer_options) if ref_image_hidden_states is not None: for layer in self.ref_image_refiner: - ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb) + ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb, transformer_options=transformer_options) hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1) return hidden_states - def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs): + def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs): B, C, H, W = x.shape hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) _, _, H_padded, W_padded = hidden_states.shape @@ -444,7 +444,7 @@ class OmniGen2Transformer2DModel(nn.Module): ) for layer in self.context_refiner: - text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb) + text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options) img_len = hidden_states.shape[1] combined_img_hidden_states = self.img_patch_embed_and_refine( @@ -453,13 +453,14 @@ class OmniGen2Transformer2DModel(nn.Module): noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, + transformer_options=transformer_options, ) hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1) attention_mask = None for layer in self.layers: - hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb) + hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb, transformer_options=transformer_options) hidden_states = self.norm_out(hidden_states, temb) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 04071f31c..b9f60c2b7 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -132,6 +132,7 @@ class Attention(nn.Module): encoder_hidden_states_mask: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, + transformer_options={}, ) -> Tuple[torch.Tensor, torch.Tensor]: seq_txt = encoder_hidden_states.shape[1] @@ -159,7 +160,7 @@ class Attention(nn.Module): joint_key = joint_key.flatten(start_dim=2) joint_value = joint_value.flatten(start_dim=2) - joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask) + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options) txt_attn_output = joint_hidden_states[:, :seq_txt, :] img_attn_output = joint_hidden_states[:, seq_txt:, :] @@ -226,6 +227,7 @@ class QwenImageTransformerBlock(nn.Module): encoder_hidden_states_mask: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + transformer_options={}, ) -> Tuple[torch.Tensor, torch.Tensor]: img_mod_params = self.img_mod(temb) txt_mod_params = self.txt_mod(temb) @@ -242,6 +244,7 @@ class QwenImageTransformerBlock(nn.Module): encoder_hidden_states=txt_modulated, encoder_hidden_states_mask=encoder_hidden_states_mask, image_rotary_emb=image_rotary_emb, + transformer_options=transformer_options, ) hidden_states = hidden_states + img_gate1 * img_attn_output @@ -434,9 +437,9 @@ class QwenImageTransformer2DModel(nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"]) + out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) hidden_states = out["img"] encoder_hidden_states = out["txt"] else: @@ -446,11 +449,12 @@ class QwenImageTransformer2DModel(nn.Module): encoder_hidden_states_mask=encoder_hidden_states_mask, temb=temb, image_rotary_emb=image_rotary_emb, + transformer_options=transformer_options, ) if "double_block" in patches: for p in patches["double_block"]: - out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i}) + out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options}) hidden_states = out["img"] encoder_hidden_states = out["txt"] diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 47857dc2b..63472ada2 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -52,7 +52,7 @@ class WanSelfAttention(nn.Module): self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() - def forward(self, x, freqs): + def forward(self, x, freqs, transformer_options={}): r""" Args: x(Tensor): Shape [B, L, num_heads, C / num_heads] @@ -75,6 +75,7 @@ class WanSelfAttention(nn.Module): k.view(b, s, n * d), v, heads=self.num_heads, + transformer_options=transformer_options, ) x = self.o(x) @@ -83,7 +84,7 @@ class WanSelfAttention(nn.Module): class WanT2VCrossAttention(WanSelfAttention): - def forward(self, x, context, **kwargs): + def forward(self, x, context, transformer_options={}, **kwargs): r""" Args: x(Tensor): Shape [B, L1, C] @@ -95,7 +96,7 @@ class WanT2VCrossAttention(WanSelfAttention): v = self.v(context) # compute attention - x = optimized_attention(q, k, v, heads=self.num_heads) + x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options) x = self.o(x) return x @@ -116,7 +117,7 @@ class WanI2VCrossAttention(WanSelfAttention): # self.alpha = nn.Parameter(torch.zeros((1, ))) self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() - def forward(self, x, context, context_img_len): + def forward(self, x, context, context_img_len, transformer_options={}): r""" Args: x(Tensor): Shape [B, L1, C] @@ -131,9 +132,9 @@ class WanI2VCrossAttention(WanSelfAttention): v = self.v(context) k_img = self.norm_k_img(self.k_img(context_img)) v_img = self.v_img(context_img) - img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads) + img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options) # compute attention - x = optimized_attention(q, k, v, heads=self.num_heads) + x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options) # output x = x + img_x @@ -206,6 +207,7 @@ class WanAttentionBlock(nn.Module): freqs, context, context_img_len=257, + transformer_options={}, ): r""" Args: @@ -224,12 +226,12 @@ class WanAttentionBlock(nn.Module): # self-attention y = self.self_attn( torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), - freqs) + freqs, transformer_options=transformer_options) x = torch.addcmul(x, y, repeat_e(e[2], x)) # cross-attention & ffn - x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len) + x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) x = torch.addcmul(x, y, repeat_e(e[5], x)) return x @@ -559,12 +561,12 @@ class WanModel(torch.nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) # head x = self.head(x, e) @@ -742,17 +744,17 @@ class VaceWanModel(WanModel): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) ii = self.vace_layers_mapping.get(i, None) if ii is not None: for iii in range(len(c)): - c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) x += c_skip * vace_strength[iii] del c_skip # head @@ -841,12 +843,12 @@ class CameraWanModel(WanModel): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) # head x = self.head(x, e)