Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-09-12 20:48:22 +00:00
Added **kwargs to all attention functions so transformer_options could potentially be passed through
@@ -194,7 +194,7 @@ def wrap_attn(func):
             finally:
                 # Important: break ref cycles so tensors aren't pinned
                 del frame
-        transformer_options = kwargs.pop("transformer_options", None)
+        transformer_options = kwargs.get("transformer_options", None)
         if transformer_options is not None:
             if "optimized_attention_override" in transformer_options:
                 return transformer_options["optimized_attention_override"](func, transformer_options, *args, **kwargs)
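
Switching from kwargs.pop() to kwargs.get() in this hunk means transformer_options is read but left in kwargs, so it is forwarded to whichever attention function the wrapper calls; that is why every wrapped function below gains **kwargs. The sketch that follows is a minimal, hypothetical reconstruction of that mechanism (the wrapper body is simplified and the attention stand-in just calls PyTorch SDPA); it is not the repository's actual implementation:

import functools
import torch

def wrap_attn(func):
    # Simplified version of the wrapper in the hunk above: transformer_options is
    # read with .get(), so it stays in kwargs and also reaches func itself.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        transformer_options = kwargs.get("transformer_options", None)
        if transformer_options is not None:
            if "optimized_attention_override" in transformer_options:
                return transformer_options["optimized_attention_override"](func, transformer_options, *args, **kwargs)
        return func(*args, **kwargs)
    return wrapper

@wrap_attn
def attention_pytorch(q, k, v, heads, mask=None, **kwargs):
    # Stand-in for the real function: plain scaled dot-product attention.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)

def my_override(func, transformer_options, *args, **kwargs):
    # Hypothetical override: inspect the options, then defer to the wrapped kernel.
    print("override invoked with options:", list(transformer_options))
    return func(*args, **kwargs)

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, tokens, dim_head)
out = attention_pytorch(q, k, v, heads=8,
                        transformer_options={"optimized_attention_override": my_override})

Because the stand-in accepts **kwargs, the transformer_options entry left in kwargs by .get() is absorbed instead of raising a TypeError.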
@@ -202,7 +202,7 @@ def wrap_attn(func):
     return wrapper

 @wrap_attn
-def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     attn_precision = get_attn_precision(attn_precision, q.dtype)

     if skip_reshape:
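
Since transformer_options now stays in kwargs, each attention function needs **kwargs in its signature to absorb it; without the change, a forwarded keyword would raise a TypeError. A tiny illustration using hypothetical stand-in signatures (not the real functions) shows the failure mode the added **kwargs avoids:

# Hypothetical "before" signature: no **kwargs, so an extra keyword is rejected.
def attention_old(q, k, v, heads, mask=None, attn_precision=None,
                  skip_reshape=False, skip_output_reshape=False):
    return "ok"

# Hypothetical "after" signature: **kwargs absorbs transformer_options (and any future extras).
def attention_new(q, k, v, heads, mask=None, attn_precision=None,
                  skip_reshape=False, skip_output_reshape=False, **kwargs):
    return "ok"

opts = {"optimized_attention_override": None}
try:
    attention_old(None, None, None, 8, transformer_options=opts)
except TypeError as e:
    print("old signature rejects it:", e)

print(attention_new(None, None, None, 8, transformer_options=opts))  # prints "ok"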
@@ -271,7 +271,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     return out

 @wrap_attn
-def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     attn_precision = get_attn_precision(attn_precision, query.dtype)

     if skip_reshape:
@@ -342,7 +342,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
     return hidden_states

 @wrap_attn
-def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     attn_precision = get_attn_precision(attn_precision, q.dtype)

     if skip_reshape:
@@ -472,7 +472,7 @@ except:
     pass

 @wrap_attn
-def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     b = q.shape[0]
     dim_head = q.shape[-1]
     # check to make sure xformers isn't broken
@@ -487,7 +487,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
             disabled_xformers = True

     if disabled_xformers:
-        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
+        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)

     if skip_reshape:
         # b h k d -> b k h d
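
The fallback path gets the same treatment: when xformers is disabled, attention_xformers now forwards its **kwargs to attention_pytorch, so transformer_options (or any other extra keyword) is not dropped on the way down. A small sketch with hypothetical stub functions, just to show the forwarding:

def attention_pytorch_stub(q, k, v, heads, mask=None, skip_reshape=False, **kwargs):
    # Stand-in for the real fallback; just report which extra keywords arrived.
    return sorted(kwargs)

def attention_xformers_stub(q, k, v, heads, mask=None, skip_reshape=False, **kwargs):
    disabled_xformers = True  # pretend xformers is unavailable
    if disabled_xformers:
        # Forwarding **kwargs keeps transformer_options alive through the fallback.
        return attention_pytorch_stub(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)

print(attention_xformers_stub(None, None, None, 8,
                              transformer_options={"optimized_attention_override": None}))
# -> ['transformer_options']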
@@ -541,7 +541,7 @@ else:
     SDP_BATCH_LIMIT = 2**31

 @wrap_attn
-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     if skip_reshape:
         b, _, _, dim_head = q.shape
     else:
@@ -584,7 +584,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
     return out

 @wrap_attn
-def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     if skip_reshape:
         b, _, _, dim_head = q.shape
         tensor_layout = "HND"
@@ -614,7 +614,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
                 lambda t: t.transpose(1, 2),
                 (q, k, v),
             )
-        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
+        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)

     if tensor_layout == "HND":
         if not skip_output_reshape:
@@ -648,7 +648,7 @@ except AttributeError as error:
     assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"

 @wrap_attn
-def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     if skip_reshape:
         b, _, _, dim_head = q.shape
     else:
|