mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 20:17:30 +00:00
Only do the cast on the device if the device supports it.
This commit is contained in:
@@ -481,6 +481,23 @@ def get_autocast_device(dev):
|
||||
return dev.type
|
||||
return "cuda"
|
||||
|
||||
def cast_to_device(tensor, device, dtype, copy=False):
|
||||
device_supports_cast = False
|
||||
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
|
||||
device_supports_cast = True
|
||||
elif tensor.dtype == torch.bfloat16:
|
||||
if hasattr(device, 'type') and device.type.startswith("cuda"):
|
||||
device_supports_cast = True
|
||||
|
||||
if device_supports_cast:
|
||||
if copy:
|
||||
if tensor.device == device:
|
||||
return tensor.to(dtype, copy=copy)
|
||||
return tensor.to(device, copy=copy).to(dtype)
|
||||
else:
|
||||
return tensor.to(device).to(dtype)
|
||||
else:
|
||||
return tensor.to(dtype).to(device, copy=copy)
|
||||
|
||||
def xformers_enabled():
|
||||
global directml_enabled
|
||||
|
Reference in New Issue
Block a user