diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 86cb3d384..b3bb71734 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -194,7 +194,7 @@ def wrap_attn(func):
             finally:
                 # Important: break ref cycles so tensors aren't pinned
                 del frame
-        transformer_options = kwargs.pop("transformer_options", None)
+        transformer_options = kwargs.get("transformer_options", None)
         if transformer_options is not None:
             if "optimized_attention_override" in transformer_options:
                 return transformer_options["optimized_attention_override"](func, transformer_options, *args, **kwargs)
@@ -202,7 +202,7 @@ def wrap_attn(func):
     return wrapper
 
 @wrap_attn
-def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     attn_precision = get_attn_precision(attn_precision, q.dtype)
 
     if skip_reshape:
@@ -271,7 +271,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     return out
 
 @wrap_attn
-def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     attn_precision = get_attn_precision(attn_precision, query.dtype)
 
     if skip_reshape:
@@ -342,7 +342,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
     return hidden_states
 
 @wrap_attn
-def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     attn_precision = get_attn_precision(attn_precision, q.dtype)
 
     if skip_reshape:
@@ -472,7 +472,7 @@ except:
     pass
 
 @wrap_attn
-def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     b = q.shape[0]
     dim_head = q.shape[-1]
     # check to make sure xformers isn't broken
@@ -487,7 +487,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
             disabled_xformers = True
 
     if disabled_xformers:
-        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
+        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
 
     if skip_reshape:
         # b h k d -> b k h d
@@ -541,7 +541,7 @@ else:
     SDP_BATCH_LIMIT = 2**31
 
 @wrap_attn
-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     if skip_reshape:
         b, _, _, dim_head = q.shape
     else:
@@ -584,7 +584,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
     return out
 
 @wrap_attn
-def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     if skip_reshape:
         b, _, _, dim_head = q.shape
         tensor_layout = "HND"
@@ -614,7 +614,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
             lambda t: t.transpose(1, 2),
             (q, k, v),
         )
-        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
+        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)
 
     if tensor_layout == "HND":
         if not skip_output_reshape:
@@ -648,7 +648,7 @@ except AttributeError as error:
         assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
 
 @wrap_attn
-def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     if skip_reshape:
         b, _, _, dim_head = q.shape
     else:
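
For context, a rough sketch of the hook these changes keep working. This is illustrative only: `my_attention_override` and the bare dict below are hypothetical names for the example; the only parts taken from the diff are the `optimized_attention_override` key and the `(func, transformer_options, *args, **kwargs)` call shape used by `wrap_attn`.

```python
# Illustrative sketch only -- not part of this diff.
def my_attention_override(func, transformer_options, *args, **kwargs):
    # func is the undecorated backend (e.g. attention_pytorch), so calling it
    # directly does not re-enter the wrapper.
    print(f"attention call routed through override: {func.__name__}")
    # Because wrap_attn now uses kwargs.get() instead of kwargs.pop(),
    # transformer_options is still present in kwargs here; the **kwargs added
    # to every backend signature is what lets func absorb it unchanged, and
    # the forwarded **kwargs on the fallback attention_pytorch calls keep it
    # available when a backend falls back at runtime.
    return func(*args, **kwargs)

# Hypothetical registration; where transformer_options actually comes from is
# outside this diff.
transformer_options = {"optimized_attention_override": my_attention_override}
```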