mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-08 15:17:14 +00:00
Per device stream counters for async offload. (#7873)
This commit is contained in:
parent
5c5457a4ef
commit
0a66d4b0af
@ -946,9 +946,9 @@ if args.async_offload:
|
|||||||
NUM_STREAMS = 2
|
NUM_STREAMS = 2
|
||||||
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
|
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
|
||||||
|
|
||||||
stream_counter = 0
|
stream_counters = {}
|
||||||
def get_offload_stream(device):
|
def get_offload_stream(device):
|
||||||
global stream_counter
|
stream_counter = stream_counters.get(device, 0)
|
||||||
if NUM_STREAMS <= 1:
|
if NUM_STREAMS <= 1:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -958,6 +958,7 @@ def get_offload_stream(device):
|
|||||||
stream_counter = (stream_counter + 1) % len(ss)
|
stream_counter = (stream_counter + 1) % len(ss)
|
||||||
if is_device_cuda(device):
|
if is_device_cuda(device):
|
||||||
ss[stream_counter].wait_stream(torch.cuda.current_stream())
|
ss[stream_counter].wait_stream(torch.cuda.current_stream())
|
||||||
|
stream_counters[device] = stream_counter
|
||||||
return s
|
return s
|
||||||
elif is_device_cuda(device):
|
elif is_device_cuda(device):
|
||||||
ss = []
|
ss = []
|
||||||
@ -966,6 +967,7 @@ def get_offload_stream(device):
|
|||||||
STREAMS[device] = ss
|
STREAMS[device] = ss
|
||||||
s = ss[stream_counter]
|
s = ss[stream_counter]
|
||||||
stream_counter = (stream_counter + 1) % len(ss)
|
stream_counter = (stream_counter + 1) % len(ss)
|
||||||
|
stream_counters[device] = stream_counter
|
||||||
return s
|
return s
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user