diff --git a/comfy/model_management.py b/comfy/model_management.py index ef9bec545..0c51eee51 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -553,15 +553,19 @@ def cast_to_device(tensor, device, dtype, copy=False): elif is_intel_xpu(): device_supports_cast = True + non_blocking = True + if is_device_mps(device): + non_blocking = False #pytorch bug? mps doesn't support non blocking + if device_supports_cast: if copy: if tensor.device == device: - return tensor.to(dtype, copy=copy, non_blocking=True) - return tensor.to(device, copy=copy, non_blocking=True).to(dtype, non_blocking=True) + return tensor.to(dtype, copy=copy, non_blocking=non_blocking) + return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking) else: - return tensor.to(device, non_blocking=True).to(dtype, non_blocking=True) + return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking) else: - return tensor.to(device, dtype, copy=copy, non_blocking=True) + return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking) def xformers_enabled(): global directml_enabled