diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 043df28df..0ab5cf16d 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -7,6 +7,7 @@ from torch import nn, einsum
 from einops import rearrange, repeat
 from typing import Optional
 import logging
+import functools
 
 from .diffusionmodules.util import AlphaBlender, timestep_embedding
 from .sub_quadratic_attention import efficient_dot_product_attention
@@ -91,6 +92,17 @@ class FeedForward(nn.Module):
 def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
 
+def wrap_attn(func):
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        transformer_options = kwargs.pop("transformer_options", None)
+        if transformer_options is not None:
+            if "optimized_attention_override" in transformer_options:
+                return transformer_options["optimized_attention_override"](*args, **kwargs)
+        return func(*args, **kwargs)
+    return wrapper
+
+@wrap_attn
 def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     attn_precision = get_attn_precision(attn_precision, q.dtype)
 
@@ -159,7 +171,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
         )
     return out
 
-
+@wrap_attn
 def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     attn_precision = get_attn_precision(attn_precision, query.dtype)
 
@@ -230,6 +242,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
         hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
     return hidden_states
 
+@wrap_attn
 def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     attn_precision = get_attn_precision(attn_precision, q.dtype)
 
@@ -359,6 +372,7 @@ try:
 except:
     pass
 
+@wrap_attn
 def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     b = q.shape[0]
     dim_head = q.shape[-1]
@@ -427,7 +441,7 @@ else:
     #TODO: other GPUs ?
     SDP_BATCH_LIMIT = 2**31
 
-
+@wrap_attn
 def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     if skip_reshape:
         b, _, _, dim_head = q.shape
@@ -470,7 +484,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
             ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
     return out
 
-
+@wrap_attn
 def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     if skip_reshape:
         b, _, _, dim_head = q.shape
@@ -534,7 +548,7 @@ except AttributeError as error:
                            dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
         assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
 
-
+@wrap_attn
 def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     if skip_reshape:
         b, _, _, dim_head = q.shape
@@ -629,7 +643,7 @@ class CrossAttention(nn.Module):
 
         self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
 
-    def forward(self, x, context=None, value=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None, transformer_options={}):
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
@@ -640,9 +654,9 @@ class CrossAttention(nn.Module):
             v = self.to_v(context)
 
         if mask is None:
-            out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
+            out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
         else:
-            out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
+            out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
         return self.to_out(out)
 
 
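For reviewers, a minimal sketch of how the new hook could be exercised. Only the `optimized_attention_override` key comes from this diff; the `my_attention_override` name, the logging, and the fallback to `attention_pytorch` are hypothetical illustration. Note that `wrap_attn` pops `transformer_options` before dispatching, so the override receives exactly the arguments the wrapped attention function was called with, minus `transformer_options` itself.

```python
# Sketch only: assumes a ComfyUI environment where transformer_options is
# threaded down to CrossAttention.forward() as wired up in this diff.
from comfy.ldm.modules.attention import attention_pytorch

def my_attention_override(q, k, v, heads, mask=None, attn_precision=None,
                          skip_reshape=False, skip_output_reshape=False, **kwargs):
    # Hypothetical override: log the call, then fall back to the stock
    # PyTorch SDPA implementation defined in this same file.
    print(f"attention override: q={tuple(q.shape)}, heads={heads}")
    return attention_pytorch(q, k, v, heads, mask=mask, attn_precision=attn_precision,
                             skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape)

# wrap_attn() looks the callback up by key, so registering it is just:
transformer_options = {"optimized_attention_override": my_attention_override}
```

How this `transformer_options` dict reaches `CrossAttention.forward` depends on the calling model code; the diff shown here only adds the plumbing inside `CrossAttention` and the decorated attention backends.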