This commit is contained in:
comfyanonymous
2023-09-23 00:57:17 -04:00

View File

@@ -488,6 +488,8 @@ def cast_to_device(tensor, device, dtype, copy=False):
elif tensor.dtype == torch.bfloat16:
if hasattr(device, 'type') and device.type.startswith("cuda"):
device_supports_cast = True
elif is_intel_xpu():
device_supports_cast = True
if device_supports_cast:
if copy: