Added **kwargs to all attention functions so transformer_options could potentially be passed through

Author: Jedrzej Kosinski
Date: 2025-08-28 13:14:41 -07:00
parent dd21b4aa51
commit 669b9ef8e6
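
For context, a self-contained sketch of the dispatch this commit enables. wrap_attn and the "optimized_attention_override" key mirror the diff below; toy_attention and my_override are hypothetical stand-ins used only for illustration.

def wrap_attn(func):
    def wrapper(*args, **kwargs):
        # get() (rather than pop()) leaves transformer_options in kwargs
        transformer_options = kwargs.get("transformer_options", None)
        if transformer_options is not None:
            if "optimized_attention_override" in transformer_options:
                return transformer_options["optimized_attention_override"](func, transformer_options, *args, **kwargs)
        return func(*args, **kwargs)
    return wrapper

@wrap_attn
def toy_attention(q, k, v, heads, **kwargs):
    # **kwargs absorbs transformer_options, which get() left in place
    return "default path"

def my_override(func, transformer_options, *args, **kwargs):
    # A real override might dispatch to a custom kernel; here we just tag
    # the call and fall back to the wrapped function.
    return "override -> " + func(*args, **kwargs)

print(toy_attention(1, 2, 3, heads=8))  # default path
print(toy_attention(1, 2, 3, heads=8, transformer_options={
    "optimized_attention_override": my_override}))  # override -> default path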


@@ -194,7 +194,7 @@ def wrap_attn(func):
         finally:
             # Important: break ref cycles so tensors aren't pinned
             del frame
-        transformer_options = kwargs.pop("transformer_options", None)
+        transformer_options = kwargs.get("transformer_options", None)
         if transformer_options is not None:
             if "optimized_attention_override" in transformer_options:
                 return transformer_options["optimized_attention_override"](func, transformer_options, *args, **kwargs)
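
Note the pop() → get() swap above: get() leaves transformer_options in kwargs, so it reaches both the override and, now that every attention function accepts **kwargs, the wrapped function itself.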
@@ -202,7 +202,7 @@ def wrap_attn(func):
     return wrapper

 @wrap_attn
-def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     attn_precision = get_attn_precision(attn_precision, q.dtype)

     if skip_reshape:
@@ -271,7 +271,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     return out

 @wrap_attn
-def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     attn_precision = get_attn_precision(attn_precision, query.dtype)

     if skip_reshape:
@@ -342,7 +342,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
     return hidden_states

 @wrap_attn
-def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     attn_precision = get_attn_precision(attn_precision, q.dtype)

     if skip_reshape:
@@ -472,7 +472,7 @@ except:
    pass

 @wrap_attn
-def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     b = q.shape[0]
     dim_head = q.shape[-1]
     # check to make sure xformers isn't broken
@@ -487,7 +487,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
         disabled_xformers = True

     if disabled_xformers:
-        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
+        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)

     if skip_reshape:
         # b h k d -> b k h d
@@ -541,7 +541,7 @@ else:
    SDP_BATCH_LIMIT = 2**31

 @wrap_attn
-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     if skip_reshape:
         b, _, _, dim_head = q.shape
     else:
@@ -584,7 +584,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
     return out

 @wrap_attn
-def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     if skip_reshape:
         b, _, _, dim_head = q.shape
         tensor_layout = "HND"
@@ -614,7 +614,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
             lambda t: t.transpose(1, 2),
             (q, k, v),
         )
-        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
+        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)

     if tensor_layout == "HND":
         if not skip_output_reshape:
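
The call-site changes here and in attention_xformers matter for the same reason: when attention_sage or attention_xformers falls back to attention_pytorch, forwarding **kwargs keeps transformer_options intact across the fallback.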
@@ -648,7 +648,7 @@ except AttributeError as error:
    assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"

 @wrap_attn
-def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     if skip_reshape:
         b, _, _, dim_head = q.shape
     else:
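
On the consumer side, an override would be installed through transformer_options. A minimal sketch, assuming a ComfyUI ModelPatcher `m` whose model_options dict carries transformer_options; the helper name log_attention_calls is made up:

def log_attention_calls(func, transformer_options, *args, **kwargs):
    # Hypothetical override: observe every attention call, then defer to
    # whichever attention implementation was wrapped.
    print("attention via", func.__name__)
    return func(*args, **kwargs)

m.model_options.setdefault("transformer_options", {})["optimized_attention_override"] = log_attention_calls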