Add pytorch attention support to VAE.

This commit is contained in:
comfyanonymous
2023-03-13 12:25:19 -04:00
parent a256a2abde
commit 83f23f82b8
4 changed files with 83 additions and 13 deletions

View File

@@ -479,23 +479,19 @@ class CrossAttentionPytorch(nn.Module):
return self.to_out(out)
import sys
if model_management.xformers_enabled() == False:
if model_management.xformers_enabled():
print("Using xformers cross attention")
CrossAttention = MemoryEfficientCrossAttention
elif model_management.pytorch_attention_enabled():
print("Using pytorch cross attention")
CrossAttention = CrossAttentionPytorch
else:
if "--use-split-cross-attention" in sys.argv:
print("Using split optimization for cross attention")
CrossAttention = CrossAttentionDoggettx
else:
if "--use-pytorch-cross-attention" in sys.argv:
print("Using pytorch cross attention")
torch.backends.cuda.enable_math_sdp(False)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
CrossAttention = CrossAttentionPytorch
else:
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
CrossAttention = CrossAttentionBirchSan
else:
print("Using xformers cross attention")
CrossAttention = MemoryEfficientCrossAttention
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
CrossAttention = CrossAttentionBirchSan
class BasicTransformerBlock(nn.Module):