diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 309240d5..fd8888d0 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -351,8 +351,11 @@ else: optimized_attention_masked = optimized_attention def optimized_attention_for_device(device, mask=False, small_input=False): - if small_input and model_management.pytorch_attention_enabled(): - return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases + if small_input: + if model_management.pytorch_attention_enabled(): + return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases + else: + return attention_basic if device == torch.device("cpu"): return attention_sub_quad