Looking into a @wrap_attn decorator that checks for an 'optimized_attention_override' entry in transformer_options

Jedrzej Kosinski
2025-08-27 14:18:18 -07:00
parent b20ba1f27c
commit b58db6934c
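
In effect, any attention kernel decorated with @wrap_attn below can be redirected at call time by placing a callable under the "optimized_attention_override" key of transformer_options; the wrapper pops transformer_options from kwargs, so the override receives exactly the arguments the caller passed (minus transformer_options). A minimal sketch of how the hook could be exercised (the module path and the logging_override helper are illustrative assumptions, not part of this commit):

import torch
import comfy.ldm.modules.attention as attn  # assumed module path for this file

def logging_override(q, k, v, heads, mask=None, **kwargs):
    # Hypothetical override: report that the hook fired, then delegate to a
    # stock kernel. transformer_options was already popped by the wrapper,
    # so delegating to another wrapped kernel does not recurse.
    print(f"override fired: q={tuple(q.shape)}, heads={heads}")
    return attn.attention_basic(q, k, v, heads, mask=mask, **kwargs)

q = k = v = torch.randn(1, 77, 512)
out = attn.attention_pytorch(
    q, k, v, heads=8,
    transformer_options={"optimized_attention_override": logging_override},
)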


@@ -7,6 +7,7 @@ from torch import nn, einsum
 from einops import rearrange, repeat
 from typing import Optional
 import logging
+import functools
 
 from .diffusionmodules.util import AlphaBlender, timestep_embedding
 from .sub_quadratic_attention import efficient_dot_product_attention
@@ -91,6 +92,17 @@ class FeedForward(nn.Module):
 def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
 
+def wrap_attn(func):
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        transformer_options = kwargs.pop("transformer_options", None)
+        if transformer_options is not None:
+            if "optimized_attention_override" in transformer_options:
+                return transformer_options["optimized_attention_override"](*args, **kwargs)
+        return func(*args, **kwargs)
+    return wrapper
+
+@wrap_attn
 def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     attn_precision = get_attn_precision(attn_precision, q.dtype)
@@ -159,7 +171,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
         )
     return out
 
-
+@wrap_attn
 def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     attn_precision = get_attn_precision(attn_precision, query.dtype)
@@ -230,6 +242,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
         hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
     return hidden_states
 
+@wrap_attn
 def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     attn_precision = get_attn_precision(attn_precision, q.dtype)
@@ -359,6 +372,7 @@ try:
 except:
     pass
 
+@wrap_attn
 def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     b = q.shape[0]
     dim_head = q.shape[-1]
@@ -427,7 +441,7 @@ else:
     #TODO: other GPUs ?
     SDP_BATCH_LIMIT = 2**31
 
-
+@wrap_attn
 def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     if skip_reshape:
         b, _, _, dim_head = q.shape
@@ -470,7 +484,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
         ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
     return out
 
-
+@wrap_attn
 def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     if skip_reshape:
         b, _, _, dim_head = q.shape
@@ -534,7 +548,7 @@ except AttributeError as error:
                            dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
         assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
 
-
+@wrap_attn
 def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     if skip_reshape:
         b, _, _, dim_head = q.shape
@@ -629,7 +643,7 @@ class CrossAttention(nn.Module):
         self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
 
-    def forward(self, x, context=None, value=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None, transformer_options={}):
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
@@ -640,9 +654,9 @@ class CrossAttention(nn.Module):
             v = self.to_v(context)
 
         if mask is None:
-            out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
+            out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
         else:
-            out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
+            out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
 
         return self.to_out(out)
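
With the forward() changes above, the same hook now threads from CrossAttention down into whichever kernel optimized_attention resolves to. A sketch, assuming attn_layer is an already-constructed CrossAttention module and logging_override is the callable from the earlier example:

hidden = torch.randn(2, 77, 512)
opts = {"optimized_attention_override": logging_override}
out = attn_layer(hidden, transformer_options=opts)  # override fires inside forward()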