diff --git a/comfy/model_management.py b/comfy/model_management.py index 78317af3..44aff376 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -946,9 +946,9 @@ if args.async_offload: NUM_STREAMS = 2 logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS)) -stream_counter = 0 +stream_counters = {} def get_offload_stream(device): - global stream_counter + stream_counter = stream_counters.get(device, 0) if NUM_STREAMS <= 1: return None @@ -958,6 +958,7 @@ def get_offload_stream(device): stream_counter = (stream_counter + 1) % len(ss) if is_device_cuda(device): ss[stream_counter].wait_stream(torch.cuda.current_stream()) + stream_counters[device] = stream_counter return s elif is_device_cuda(device): ss = [] @@ -966,6 +967,7 @@ def get_offload_stream(device): STREAMS[device] = ss s = ss[stream_counter] stream_counter = (stream_counter + 1) % len(ss) + stream_counters[device] = stream_counter return s return None