mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-14 05:25:23 +00:00
Use the lowvram cast_to function for everything.
This commit is contained in:
@@ -840,27 +840,21 @@ def force_channels_last():
|
||||
#TODO
|
||||
return False
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
return r
|
||||
|
||||
def cast_to_device(tensor, device, dtype, copy=False):
|
||||
device_supports_cast = False
|
||||
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
|
||||
device_supports_cast = True
|
||||
elif tensor.dtype == torch.bfloat16:
|
||||
if hasattr(device, 'type') and device.type.startswith("cuda"):
|
||||
device_supports_cast = True
|
||||
elif is_intel_xpu():
|
||||
device_supports_cast = True
|
||||
non_blocking = device_supports_non_blocking(device)
|
||||
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
non_blocking = device_should_use_non_blocking(device)
|
||||
|
||||
if device_supports_cast:
|
||||
if copy:
|
||||
if tensor.device == device:
|
||||
return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
|
||||
return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
|
||||
else:
|
||||
return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
|
||||
else:
|
||||
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
|
||||
|
||||
def xformers_enabled():
|
||||
global directml_enabled
|
||||
|
Reference in New Issue
Block a user