Added **kwargs to all attention functions so transformer_options can be passed through

Jedrzej Kosinski
2025-08-28 13:14:41 -07:00
parent dd21b4aa51
commit 669b9ef8e6

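For illustration, a minimal caller-side sketch of what this change enables, assuming these attention functions live in ComfyUI's comfy.ldm.modules.attention (the import path, tensor shapes, and the no-op override are illustrative, not part of this commit). Because every wrapped function now accepts **kwargs, transformer_options can be passed as an ordinary keyword argument and picked up by the wrap_attn wrapper shown below:

import torch
from comfy.ldm.modules.attention import attention_pytorch  # assumed import path

def passthrough_override(func, options, *args, **kwargs):
    # Defer to func, the original (undecorated) attention implementation.
    return func(*args, **kwargs)

q = k = v = torch.randn(2, 77, 512)  # (batch, tokens, heads * dim_head); heads=8 -> dim_head=64
out = attention_pytorch(
    q, k, v, heads=8,
    transformer_options={"optimized_attention_override": passthrough_override},
)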

@@ -194,7 +194,7 @@ def wrap_attn(func):
finally:
# Important: break ref cycles so tensors aren't pinned
del frame
-transformer_options = kwargs.pop("transformer_options", None)
+transformer_options = kwargs.get("transformer_options", None)
if transformer_options is not None:
if "optimized_attention_override" in transformer_options:
return transformer_options["optimized_attention_override"](func, transformer_options, *args, **kwargs)
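At the call site above, whatever is registered under "optimized_attention_override" is invoked as override(func, transformer_options, *args, **kwargs), where func is the undecorated attention function captured by wrap_attn. A minimal sketch of such a callable follows (the name and the timing logic are illustrative only). Note that because the line above now uses kwargs.get() instead of pop(), transformer_options also remains inside **kwargs, so the override's second positional parameter should not itself be named transformer_options or Python would receive that argument twice:

import time

def timed_attention_override(func, options, q, k, v, heads, mask=None, **kwargs):
    # func is the raw (unwrapped) attention implementation, so calling it here
    # does not re-enter the wrapper and cannot recurse back into this override.
    start = time.perf_counter()
    # kwargs still carries transformer_options; func's own **kwargs absorbs it.
    out = func(q, k, v, heads, mask=mask, **kwargs)
    print(f"{getattr(func, '__name__', 'attention')} took {time.perf_counter() - start:.4f}s")
    return out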
@@ -202,7 +202,7 @@ def wrap_attn(func):
return wrapper
@wrap_attn
-def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape:
@@ -271,7 +271,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
return out
@wrap_attn
-def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
attn_precision = get_attn_precision(attn_precision, query.dtype)
if skip_reshape:
@@ -342,7 +342,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
return hidden_states
@wrap_attn
-def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape:
@@ -472,7 +472,7 @@ except:
pass
@wrap_attn
-def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
b = q.shape[0]
dim_head = q.shape[-1]
# check to make sure xformers isn't broken
@@ -487,7 +487,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
disabled_xformers = True
if disabled_xformers:
-return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
+return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
if skip_reshape:
# b h k d -> b k h d
@@ -541,7 +541,7 @@ else:
SDP_BATCH_LIMIT = 2**31
@wrap_attn
-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
@@ -584,7 +584,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
return out
@wrap_attn
-def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout = "HND"
@@ -614,7 +614,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
lambda t: t.transpose(1, 2),
(q, k, v),
)
-return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
+return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)
if tensor_layout == "HND":
if not skip_output_reshape:
@@ -648,7 +648,7 @@ except AttributeError as error:
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
@wrap_attn
-def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
b, _, _, dim_head = q.shape
else: