Add experimental --async-offload lowvram weight offloading. (#7820)

This should speed up the lowvram mode a bit. It currently is only enabled when --async-offload is used but it will be enabled by default in the future if there are no problems.
This commit is contained in:
comfyanonymous
2025-04-26 13:11:21 -07:00
committed by GitHub
parent b685b8a4e0
commit 0dcc75ca54
3 changed files with 50 additions and 5 deletions

View File

@@ -939,15 +939,56 @@ def force_channels_last():
#TODO
return False
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
STREAMS = {}
NUM_STREAMS = 1
if args.async_offload:
NUM_STREAMS = 2
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
stream_counter = 0
def get_offload_stream(device):
global stream_counter
if NUM_STREAMS <= 1:
return None
if device in STREAMS:
ss = STREAMS[device]
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
if is_device_cuda(device):
ss[stream_counter].wait_stream(torch.cuda.current_stream())
return s
elif is_device_cuda(device):
ss = []
for k in range(NUM_STREAMS):
ss.append(torch.cuda.Stream(device=device, priority=10))
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
return s
return None
def sync_stream(device, stream):
if stream is None:
return
if is_device_cuda(device):
torch.cuda.current_stream().wait_stream(stream)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
return weight.to(dtype=dtype, copy=copy)
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
if stream is not None:
with stream:
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
else:
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
return r
def cast_to_device(tensor, device, dtype, copy=False):