Only do the cast on the device if the device supports it.

This commit is contained in:
comfyanonymous
2023-09-20 17:52:41 -04:00
parent b92a86d737
commit 1cdfb3dba4
2 changed files with 46 additions and 14 deletions

View File

@@ -481,6 +481,23 @@ def get_autocast_device(dev):
return dev.type
return "cuda"
def cast_to_device(tensor, device, dtype, copy=False):
device_supports_cast = False
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
device_supports_cast = True
elif tensor.dtype == torch.bfloat16:
if hasattr(device, 'type') and device.type.startswith("cuda"):
device_supports_cast = True
if device_supports_cast:
if copy:
if tensor.device == device:
return tensor.to(dtype, copy=copy)
return tensor.to(device, copy=copy).to(dtype)
else:
return tensor.to(device).to(dtype)
else:
return tensor.to(dtype).to(device, copy=copy)
def xformers_enabled():
global directml_enabled