Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-08-17 10:24:21 +00:00
Make last PR not crash comfy on old pytorch. (#9324)
commit 9df8792d4b
parent 3da5a07510
@@ -178,7 +178,7 @@ class FourierEmbedder(nn.Module):
 
 class CrossAttentionProcessor:
     def __call__(self, attn, q, k, v):
-        out = ops.scaled_dot_product_attention(q, k, v)
+        out = comfy.ops.scaled_dot_product_attention(q, k, v)
         return out
 
 
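Aside (not part of the diff): the call-site change above swaps the class-level ops.scaled_dot_product_attention for the new module-level comfy.ops.scaled_dot_product_attention, which forwards straight to torch.nn.functional.scaled_dot_product_attention (with backend priority applied when supported, per the comfy/ops.py hunk further down). A minimal usage sketch, assuming a ComfyUI checkout is importable and using illustrative tensor shapes:

import torch
import comfy.ops

# Illustrative (batch, heads, tokens, head_dim) shapes; any SDPA-compatible shapes work.
q = torch.randn(1, 8, 77, 64)
k = torch.randn(1, 8, 77, 64)
v = torch.randn(1, 8, 77, 64)

# Same signature as torch.nn.functional.scaled_dot_product_attention;
# extra positional/keyword arguments are passed straight through.
out = comfy.ops.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 77, 64])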
@@ -448,7 +448,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
             mask = mask.unsqueeze(1)
 
     if SDP_BATCH_LIMIT >= b:
-        out = ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+        out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
         if not skip_output_reshape:
             out = (
                 out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@@ -461,7 +461,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
             if mask.shape[0] > 1:
                 m = mask[i : i + SDP_BATCH_LIMIT]
 
-            out[i : i + SDP_BATCH_LIMIT] = ops.scaled_dot_product_attention(
+            out[i : i + SDP_BATCH_LIMIT] = comfy.ops.scaled_dot_product_attention(
                 q[i : i + SDP_BATCH_LIMIT],
                 k[i : i + SDP_BATCH_LIMIT],
                 v[i : i + SDP_BATCH_LIMIT],
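Aside (not part of the diff): the two hunks above sit inside attention_pytorch's batch-splitting logic, where batches larger than SDP_BATCH_LIMIT are run through SDPA slice by slice to bound peak memory. A simplified sketch of that pattern, with an assumed limit and without ComfyUI's reshape and precision handling:

import torch

SDP_BATCH_LIMIT = 2 ** 15  # assumed value, for illustration only

def chunked_sdpa(q, k, v, mask=None):
    b = q.shape[0]
    if SDP_BATCH_LIMIT >= b:
        # Small batch: one fused SDPA call.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    # Large batch: compute attention in SDP_BATCH_LIMIT-sized slices.
    out = torch.empty_like(q)  # assumes v shares q's head_dim
    for i in range(0, b, SDP_BATCH_LIMIT):
        m = mask
        if mask is not None and mask.shape[0] > 1:
            m = mask[i : i + SDP_BATCH_LIMIT]
        out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
            q[i : i + SDP_BATCH_LIMIT],
            k[i : i + SDP_BATCH_LIMIT],
            v[i : i + SDP_BATCH_LIMIT],
            attn_mask=m,
        )
    return out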
@@ -285,7 +285,7 @@ def pytorch_attention(q, k, v):
         )
 
     try:
-        out = ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+        out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
         out = out.transpose(2, 3).reshape(orig_shape)
     except model_management.OOM_EXCEPTION:
         logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
comfy/ops.py
@@ -23,18 +23,32 @@ from comfy.cli_args import args, PerformanceFeature
 import comfy.float
 import comfy.rmsnorm
 import contextlib
-from torch.nn.attention import SDPBackend, sdpa_kernel
 
-cast_to = comfy.model_management.cast_to #TODO: remove once no more references
+def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+
+
+try:
+    if torch.cuda.is_available():
+        from torch.nn.attention import SDPBackend, sdpa_kernel
 
-SDPA_BACKEND_PRIORITY = [
-    SDPBackend.FLASH_ATTENTION,
-    SDPBackend.EFFICIENT_ATTENTION,
-    SDPBackend.MATH,
-]
-if torch.cuda.is_available():
-    SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
+        SDPA_BACKEND_PRIORITY = [
+            SDPBackend.FLASH_ATTENTION,
+            SDPBackend.EFFICIENT_ATTENTION,
+            SDPBackend.MATH,
+        ]
+
+        SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
+
+        @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True)
+        def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+            return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+except (ModuleNotFoundError, TypeError):
+    logging.warning("Could not set sdpa backend priority.")
 
+cast_to = comfy.model_management.cast_to #TODO: remove once no more references
 
 def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
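Aside (not part of the diff): the hunk above is the compatibility fix itself. A plain wrapper around torch.nn.functional.scaled_dot_product_attention is always defined first; only when CUDA is available and the running torch exposes torch.nn.attention.sdpa_kernel with the set_priority keyword is it replaced by a backend-prioritized version. On older torch the import raises ModuleNotFoundError, or the decorator call raises TypeError, and the plain wrapper stays in place. A self-contained sketch of the same pattern, runnable outside ComfyUI:

import logging
import torch

def scaled_dot_product_attention(q, k, v, *args, **kwargs):
    # Fallback used on CPU-only setups and on older torch versions.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)

try:
    if torch.cuda.is_available():
        # Missing on old torch -> ModuleNotFoundError, caught below.
        from torch.nn.attention import SDPBackend, sdpa_kernel

        PRIORITY = [
            SDPBackend.CUDNN_ATTENTION,
            SDPBackend.FLASH_ATTENTION,
            SDPBackend.EFFICIENT_ATTENTION,
            SDPBackend.MATH,
        ]

        # set_priority is a newer keyword; old torch raises TypeError here, caught below.
        @sdpa_kernel(backends=PRIORITY, set_priority=True)
        def scaled_dot_product_attention(q, k, v, *args, **kwargs):
            return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
except (ModuleNotFoundError, TypeError):
    logging.warning("Could not set sdpa backend priority.")

q = k = v = torch.randn(1, 4, 16, 32)
print(scaled_dot_product_attention(q, k, v).shape)  # torch.Size([1, 4, 16, 32])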
@@ -258,10 +272,6 @@ class disable_weight_init:
         else:
             raise ValueError(f"unsupported dimensions: {dims}")
 
-    @staticmethod
-    @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True)
-    def scaled_dot_product_attention(q, k, v, *args, **kwargs):
-        return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
 
 class manual_cast(disable_weight_init):
     class Linear(disable_weight_init.Linear):
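Aside (not part of the diff): the removal above matters because the decorated staticmethod lived in the class body of disable_weight_init, so sdpa_kernel(..., set_priority=True) was evaluated the moment comfy.ops was imported; on an old torch that raised before ComfyUI could start. A toy illustration of import-time decorator evaluation, using a made-up decorator that stands in for the unsupported keyword:

def fussy_decorator(**kwargs):
    # Stand-in for an API that rejects a keyword it does not know
    # (like an old sdpa_kernel being passed set_priority=True).
    if "set_priority" in kwargs:
        raise TypeError("unexpected keyword argument 'set_priority'")
    def wrap(fn):
        return fn
    return wrap

try:
    class Holder:
        # Decorators in a class body run when the class is created, i.e. at
        # module import time, not when the method is first called.
        @fussy_decorator(set_priority=True)
        def attention(self, q, k, v):
            return q
except TypeError as exc:
    print("failed at import time:", exc)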