pytorch_attention_enabled can now return True when xformers is enabled.

comfyanonymous
2023-10-11 21:29:03 -04:00
parent 20d3852aa1
commit 88733c997f
2 changed files with 7 additions and 4 deletions

@@ -154,14 +154,18 @@ def is_nvidia():
             return True
     return False
 
-ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
+ENABLE_PYTORCH_ATTENTION = False
+if args.use_pytorch_cross_attention:
+    ENABLE_PYTORCH_ATTENTION = True
+    XFORMERS_IS_AVAILABLE = False
+
 VAE_DTYPE = torch.float32
 
 try:
     if is_nvidia():
         torch_version = torch.version.__version__
         if int(torch_version[0]) >= 2:
-            if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
+            if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
         if torch.cuda.is_bf16_supported():
             VAE_DTYPE = torch.bfloat16
@@ -186,7 +190,6 @@ if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
     torch.backends.cuda.enable_mem_efficient_sdp(True)
-    XFORMERS_IS_AVAILABLE = False
 
 if args.lowvram:
     set_vram_to = VRAMState.LOW_VRAM
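
Net effect of the two hunks: XFORMERS_IS_AVAILABLE is forced off only when PyTorch cross attention is explicitly requested on the command line, while the automatic NVIDIA/PyTorch 2 path no longer checks or clears it. The following is a minimal, self-contained sketch of the resulting selection logic; the Args class, the is_nvidia() stand-in, and the pytorch_attention_enabled()-style accessor are illustrative reconstructions under those assumptions, not code quoted from the repository.

import torch

# Illustrative stand-in for the parsed CLI arguments referenced in the diff above.
class Args:
    use_pytorch_cross_attention = False
    use_split_cross_attention = False
    use_quad_cross_attention = False

args = Args()

XFORMERS_IS_AVAILABLE = True      # assume the xformers import succeeded
ENABLE_PYTORCH_ATTENTION = False

if args.use_pytorch_cross_attention:
    # Explicit opt-in: PyTorch SDP attention wins and xformers is disabled.
    ENABLE_PYTORCH_ATTENTION = True
    XFORMERS_IS_AVAILABLE = False

def is_nvidia():
    # Simplified stand-in for the repository's is_nvidia() check.
    return torch.cuda.is_available() and torch.version.cuda is not None

if is_nvidia() and int(torch.version.__version__[0]) >= 2:
    if not ENABLE_PYTORCH_ATTENTION and not args.use_split_cross_attention and not args.use_quad_cross_attention:
        # Auto-enable on NVIDIA + PyTorch 2.x. Unlike before this commit, the
        # condition no longer requires XFORMERS_IS_AVAILABLE to be False and
        # nothing here clears it, so both flags can end up True at once.
        ENABLE_PYTORCH_ATTENTION = True

def pytorch_attention_enabled():
    # Accessor in the spirit of the commit title: it can now return True
    # even while XFORMERS_IS_AVAILABLE is still True.
    return ENABLE_PYTORCH_ATTENTION

print(pytorch_attention_enabled(), XFORMERS_IS_AVAILABLE)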