Speed up some models by not upcasting bfloat16 to float32 on macOS.

comfyanonymous
2025-02-24 05:41:07 -05:00
parent 4553891bbd
commit 96d891cb94
2 changed files with 8 additions and 7 deletions


@@ -954,7 +954,7 @@ def force_upcast_attention_dtype():
         upcast = True

     if upcast:
-        return torch.float32
+        return {torch.float16: torch.float32}
     else:
         return None
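
For context, the patched function now returns a per-dtype mapping instead of a bare dtype, so callers upcast float16 to float32 but leave bfloat16 untouched. Below is a minimal, self-contained sketch of how such a mapping can be consumed; the attention_dtype helper and the simplified upcast flag are illustrative assumptions, not the repository's actual call site.

import torch

def force_upcast_attention_dtype():
    # Simplified stand-in for the patched function: upcast only float16.
    upcast = True  # e.g. set on affected macOS versions
    if upcast:
        return {torch.float16: torch.float32}
    else:
        return None

def attention_dtype(q_dtype):
    # Hypothetical caller: map the incoming dtype through the table.
    # Anything not listed (including torch.bfloat16) keeps its dtype.
    mapping = force_upcast_attention_dtype()
    if mapping is None:
        return q_dtype
    return mapping.get(q_dtype, q_dtype)

print(attention_dtype(torch.float16))   # torch.float32 (still upcast)
print(attention_dtype(torch.bfloat16))  # torch.bfloat16 (no upcast -> speedup)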