Add --use-flash-attention flag. (#7223)

* Add --use-flash-attention flag.
This is useful on AMD systems, as FlashAttention builds are still 10% faster than PyTorch cross-attention.
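
For readers following along, a minimal sketch of how a boolean CLI flag like this is typically registered with argparse (the help text and parser layout here are assumptions, not the repository's actual CLI definition):

import argparse

# Hypothetical registration of the new flag; wording and placement are illustrative.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--use-flash-attention",
    action="store_true",
    help="Use FlashAttention instead of the default PyTorch attention.",
)

args = parser.parse_args([])           # e.g. parse_args(["--use-flash-attention"])
print(args.use_flash_attention)        # False unless the flag is passed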
FeepingCreature
2025-03-14 08:22:41 +01:00
committed by GitHub
parent 35504e2f93
commit 7aceb9f91c
3 changed files with 64 additions and 0 deletions

@@ -930,6 +930,9 @@ def cast_to_device(tensor, device, dtype, copy=False):
 def sage_attention_enabled():
     return args.use_sage_attention

+def flash_attention_enabled():
+    return args.use_flash_attention
+
 def xformers_enabled():
     global directml_enabled
     global cpu_state
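
For context, a hedged sketch of how a flag accessor like flash_attention_enabled() is typically consumed by an attention dispatcher. The attention() wrapper, its tensor layout, and the fallback path below are illustrative assumptions rather than the repository's actual code; flash_attn_func is the public entry point of the flash-attn package:

import torch

def attention(q, k, v):
    # Illustrative dispatcher; q, k, v assumed to be (batch, heads, seq_len, head_dim),
    # fp16/bf16 on a GPU (flash-attn does not support fp32 or CPU tensors).
    if flash_attention_enabled():
        try:
            from flash_attn import flash_attn_func
            # flash-attn expects (batch, seq_len, heads, head_dim), so transpose in and out.
            out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
            return out.transpose(1, 2)
        except ImportError:
            pass  # flash-attn not installed: fall back to the PyTorch path below
    # Default path: PyTorch's fused scaled dot-product attention.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)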