execution: fold in dependency aware caching / Fix --cache-none with loops/lazy etc (Resubmit) (#10440)

* execution: fold in dependency aware caching

This makes --cache-none compatiable with lazy and expanded
subgraphs.

Currently the --cache-none option is powered by the
DependencyAwareCache. The cache attempts to maintain a parallel
copy of the execution list data structure, however it is only
setup once at the start of execution and does not get meaninigful
updates to the execution list.

This causes multiple problems when --cache-none is used with lazy
and expanded subgraphs as the DAC does not accurately update its
copy of the execution data structure.

DAC has an attempt to handle subgraphs ensure_subcache however
this does not accurately connect to nodes outside the subgraph.
The current semantics of DAC are to free a node ASAP after the
dependent nodes are executed.

This means that if a subgraph refs such a node it will be requed
and re-executed by the execution_list but DAC wont see it in
its to-free lists anymore and leak memory.

Rather than try and cover all the cases where the execution list
changes from inside the cache, move the while problem to the
executor which maintains an always up-to-date copy of the wanted
data-structure.

The executor now has a fast-moving run-local cache of its own.
Each _to node has its own mini cache, and the cache is unconditionally
primed at the time of add_strong_link.

add_strong_link is called for all of static workflows, lazy links
and expanded subgraphs so its the singular source of truth for
output dependendencies.

In the case of a cache-hit, the executor cache will hold the non-none
value (it will respect updates if they happen somehow as well).

In the case of a cache-miss, the executor caches a None and will
wait for a notification to update the value when the node completes.

When a node completes execution, it simply releases its mini-cache
and in turn its strong refs on its direct anscestor outputs, allowing
for ASAP freeing (same as the DependencyAwareCache but a little more
automatic).

This now allows for re-implementation of --cache-none with no cache
at all. The dependency aware cache was also observing the dependency
sematics for the objects and UI cache which is not accurate (this
entire logic was always outputs specific).

This also prepares for more complex caching strategies (such as RAM
pressure based caching), where a cache can implement any freeing
strategy completely independently of the DepedancyAwareness
requirement.

* main: re-implement --cache-none as no cache at all

The execution list now tracks the dependency aware caching more
correctly that the DependancyAwareCache.

Change it to a cache that does nothing.

* test_execution: add --cache-none to the test suite

--cache-none is now expected to work universally. Run it through the
full unit test suite. Propagate the server parameterization for whether
or not the server is capabale of caching, so that the minority of tests
that specifically check for cache hits can if else. Hard assert NOT
caching in the else to give some coverage of --cache-none expected
behaviour to not acutally cache.
This commit is contained in:
rattus
2025-10-23 05:49:05 +10:00
committed by GitHub
parent f13cff0be6
commit 4739d7717f
5 changed files with 102 additions and 190 deletions

View File

@@ -265,6 +265,26 @@ class HierarchicalCache(BasicCache):
assert cache is not None
return await cache._ensure_subcache(node_id, children_ids)
class NullCache:
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
pass
def all_node_ids(self):
return []
def clean_unused(self):
pass
def get(self, node_id):
return None
def set(self, node_id, value):
pass
async def ensure_subcache_for(self, node_id, children_ids):
return self
class LRUCache(BasicCache):
def __init__(self, key_class, max_size=100):
super().__init__(key_class)
@@ -316,157 +336,3 @@ class LRUCache(BasicCache):
self._mark_used(child_id)
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self
class DependencyAwareCache(BasicCache):
"""
A cache implementation that tracks dependencies between nodes and manages
their execution and caching accordingly. It extends the BasicCache class.
Nodes are removed from this cache once all of their descendants have been
executed.
"""
def __init__(self, key_class):
"""
Initialize the DependencyAwareCache.
Args:
key_class: The class used for generating cache keys.
"""
super().__init__(key_class)
self.descendants = {} # Maps node_id -> set of descendant node_ids
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
self.executed_nodes = set() # Tracks nodes that have been executed
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
"""
Clear the entire cache and rebuild the dependency graph.
Args:
dynprompt: The dynamic prompt object containing node information.
node_ids: List of node IDs to initialize the cache for.
is_changed_cache: Flag indicating if the cache has changed.
"""
# Clear all existing cache data
self.cache.clear()
self.subcaches.clear()
self.descendants.clear()
self.ancestors.clear()
self.executed_nodes.clear()
# Call the parent method to initialize the cache with the new prompt
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
# Rebuild the dependency graph
self._build_dependency_graph(dynprompt, node_ids)
def _build_dependency_graph(self, dynprompt, node_ids):
"""
Build the dependency graph for all nodes.
Args:
dynprompt: The dynamic prompt object containing node information.
node_ids: List of node IDs to build the graph for.
"""
self.descendants.clear()
self.ancestors.clear()
for node_id in node_ids:
self.descendants[node_id] = set()
self.ancestors[node_id] = set()
for node_id in node_ids:
inputs = dynprompt.get_node(node_id)["inputs"]
for input_data in inputs.values():
if is_link(input_data): # Check if the input is a link to another node
ancestor_id = input_data[0]
self.descendants[ancestor_id].add(node_id)
self.ancestors[node_id].add(ancestor_id)
def set(self, node_id, value):
"""
Mark a node as executed and store its value in the cache.
Args:
node_id: The ID of the node to store.
value: The value to store for the node.
"""
self._set_immediate(node_id, value)
self.executed_nodes.add(node_id)
self._cleanup_ancestors(node_id)
def get(self, node_id):
"""
Retrieve the cached value for a node.
Args:
node_id: The ID of the node to retrieve.
Returns:
The cached value for the node.
"""
return self._get_immediate(node_id)
async def ensure_subcache_for(self, node_id, children_ids):
"""
Ensure a subcache exists for a node and update dependencies.
Args:
node_id: The ID of the parent node.
children_ids: List of child node IDs to associate with the parent node.
Returns:
The subcache object for the node.
"""
subcache = await super()._ensure_subcache(node_id, children_ids)
for child_id in children_ids:
self.descendants[node_id].add(child_id)
self.ancestors[child_id].add(node_id)
return subcache
def _cleanup_ancestors(self, node_id):
"""
Check if ancestors of a node can be removed from the cache.
Args:
node_id: The ID of the node whose ancestors are to be checked.
"""
for ancestor_id in self.ancestors.get(node_id, []):
if ancestor_id in self.executed_nodes:
# Remove ancestor if all its descendants have been executed
if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
self._remove_node(ancestor_id)
def _remove_node(self, node_id):
"""
Remove a node from the cache.
Args:
node_id: The ID of the node to remove.
"""
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
del self.cache[cache_key]
subcache_key = self.cache_key_set.get_subcache_key(node_id)
if subcache_key in self.subcaches:
del self.subcaches[subcache_key]
def clean_unused(self):
"""
Clean up unused nodes. This is a no-op for this cache implementation.
"""
pass
def recursive_debug_dump(self):
"""
Dump the cache and dependency graph for debugging.
Returns:
A list containing the cache state and dependency graph.
"""
result = super().recursive_debug_dump()
result.append({
"descendants": self.descendants,
"ancestors": self.ancestors,
"executed_nodes": list(self.executed_nodes),
})
return result

View File

@@ -153,8 +153,9 @@ class TopologicalSort:
continue
_, _, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
node_ids.append(from_node_id)
if (include_lazy or not is_lazy):
if not self.is_cached(from_node_id):
node_ids.append(from_node_id)
links.append((from_node_id, from_socket, unique_id))
for link in links:
@@ -194,10 +195,35 @@ class ExecutionList(TopologicalSort):
super().__init__(dynprompt)
self.output_cache = output_cache
self.staged_node_id = None
self.execution_cache = {}
self.execution_cache_listeners = {}
def is_cached(self, node_id):
return self.output_cache.get(node_id) is not None
def cache_link(self, from_node_id, to_node_id):
if not to_node_id in self.execution_cache:
self.execution_cache[to_node_id] = {}
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
if not from_node_id in self.execution_cache_listeners:
self.execution_cache_listeners[from_node_id] = set()
self.execution_cache_listeners[from_node_id].add(to_node_id)
def get_output_cache(self, from_node_id, to_node_id):
if not to_node_id in self.execution_cache:
return None
return self.execution_cache[to_node_id].get(from_node_id)
def cache_update(self, node_id, value):
if node_id in self.execution_cache_listeners:
for to_node_id in self.execution_cache_listeners[node_id]:
if to_node_id in self.execution_cache:
self.execution_cache[to_node_id][node_id] = value
def add_strong_link(self, from_node_id, from_socket, to_node_id):
super().add_strong_link(from_node_id, from_socket, to_node_id)
self.cache_link(from_node_id, to_node_id)
async def stage_node_execution(self):
assert self.staged_node_id is None
if self.is_empty():
@@ -277,6 +303,8 @@ class ExecutionList(TopologicalSort):
def complete_node_execution(self):
node_id = self.staged_node_id
self.pop_node(node_id)
self.execution_cache.pop(node_id, None)
self.execution_cache_listeners.pop(node_id, None)
self.staged_node_id = None
def get_nodes_in_cycle(self):

View File

@@ -18,7 +18,7 @@ from comfy_execution.caching import (
BasicCache,
CacheKeySetID,
CacheKeySetInputSignature,
DependencyAwareCache,
NullCache,
HierarchicalCache,
LRUCache,
)
@@ -91,13 +91,13 @@ class IsChangedCache:
class CacheType(Enum):
CLASSIC = 0
LRU = 1
DEPENDENCY_AWARE = 2
NONE = 2
class CacheSet:
def __init__(self, cache_type=None, cache_size=None):
if cache_type == CacheType.DEPENDENCY_AWARE:
self.init_dependency_aware_cache()
if cache_type == CacheType.NONE:
self.init_null_cache()
logging.info("Disabling intermediate node cache.")
elif cache_type == CacheType.LRU:
if cache_size is None:
@@ -120,11 +120,12 @@ class CacheSet:
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)
# only hold cached items while the decendents have not executed
def init_dependency_aware_cache(self):
self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
self.ui = DependencyAwareCache(CacheKeySetInputSignature)
self.objects = DependencyAwareCache(CacheKeySetID)
def init_null_cache(self):
self.outputs = NullCache()
#The UI cache is expected to be iterable at the end of each workflow
#so it must cache at least a full workflow. Use Heirachical
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = NullCache()
def recursive_debug_dump(self):
result = {
@@ -135,7 +136,7 @@ class CacheSet:
SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
is_v3 = issubclass(class_def, _ComfyNodeInternal)
if is_v3:
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
@@ -153,10 +154,10 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
input_unique_id = input_data[0]
output_index = input_data[1]
if outputs is None:
if execution_list is None:
mark_missing()
continue # This might be a lazily-evaluated input
cached_output = outputs.get(input_unique_id)
cached_output = execution_list.get_output_cache(input_unique_id, unique_id)
if cached_output is None:
mark_missing()
continue
@@ -405,6 +406,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
cached_output = caches.ui.get(unique_id) or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
get_progress_state().finish_progress(unique_id)
execution_list.cache_update(unique_id, caches.outputs.get(unique_id))
return (ExecutionResult.SUCCESS, None, None)
input_data_all = None
@@ -434,7 +436,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
for r in result:
if is_link(r):
source_node, source_output = r[0], r[1]
node_output = caches.outputs.get(source_node)[source_output]
node_output = execution_list.get_output_cache(source_node, unique_id)[source_output]
for o in node_output:
resolved_output.append(o)
@@ -446,7 +448,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
has_subgraph = False
else:
get_progress_state().start_progress(unique_id)
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@@ -549,11 +551,15 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
subcache.clean_unused()
for node_id in new_output_ids:
execution_list.add_node(node_id)
execution_list.cache_link(node_id, unique_id)
for link in new_output_links:
execution_list.add_strong_link(link[0], link[1], unique_id)
pending_subgraph_results[unique_id] = cached_outputs
return (ExecutionResult.PENDING, None, None)
caches.outputs.set(unique_id, output_data)
execution_list.cache_update(unique_id, output_data)
except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted")

View File

@@ -173,7 +173,7 @@ def prompt_worker(q, server_instance):
if args.cache_lru > 0:
cache_type = execution.CacheType.LRU
elif args.cache_none:
cache_type = execution.CacheType.DEPENDENCY_AWARE
cache_type = execution.CacheType.NONE
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
last_gc_collect = 0

View File

@@ -152,12 +152,12 @@ class TestExecution:
# Initialize server and client
#
@fixture(scope="class", autouse=True, params=[
# (use_lru, lru_size)
(False, 0),
(True, 0),
(True, 100),
{ "extra_args" : [], "should_cache_results" : True },
{ "extra_args" : ["--cache-lru", 0], "should_cache_results" : True },
{ "extra_args" : ["--cache-lru", 100], "should_cache_results" : True },
{ "extra_args" : ["--cache-none"], "should_cache_results" : False },
])
def _server(self, args_pytest, request):
def server(self, args_pytest, request):
# Start server
pargs = [
'python','main.py',
@@ -167,12 +167,10 @@ class TestExecution:
'--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
'--cpu',
]
use_lru, lru_size = request.param
if use_lru:
pargs += ['--cache-lru', str(lru_size)]
pargs += [ str(param) for param in request.param["extra_args"] ]
print("Running server with args:", pargs) # noqa: T201
p = subprocess.Popen(pargs)
yield
yield request.param
p.kill()
torch.cuda.empty_cache()
@@ -193,7 +191,7 @@ class TestExecution:
return comfy_client
@fixture(scope="class", autouse=True)
def shared_client(self, args_pytest, _server):
def shared_client(self, args_pytest, server):
client = self.start_client(args_pytest["listen"], args_pytest["port"])
yield client
del client
@@ -225,7 +223,7 @@ class TestExecution:
assert result.did_run(mask)
assert result.did_run(lazy_mix)
def test_full_cache(self, client: ComfyClient, builder: GraphBuilder):
def test_full_cache(self, client: ComfyClient, builder: GraphBuilder, server):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
@@ -237,9 +235,12 @@ class TestExecution:
client.run(g)
result2 = client.run(g)
for node_id, node in g.nodes.items():
assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"
if server["should_cache_results"]:
assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"
else:
assert result2.did_run(node), f"Node {node_id} was cached, but should have been run"
def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder):
def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder, server):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
@@ -251,8 +252,12 @@ class TestExecution:
client.run(g)
mask.inputs['value'] = 0.4
result2 = client.run(g)
assert not result2.did_run(input1), "Input1 should have been cached"
assert not result2.did_run(input2), "Input2 should have been cached"
if server["should_cache_results"]:
assert not result2.did_run(input1), "Input1 should have been cached"
assert not result2.did_run(input2), "Input2 should have been cached"
else:
assert result2.did_run(input1), "Input1 should have been rerun"
assert result2.did_run(input2), "Input2 should have been rerun"
def test_error(self, client: ComfyClient, builder: GraphBuilder):
g = builder
@@ -411,7 +416,7 @@ class TestExecution:
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
client.run(g)
def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder, server):
g = builder
# Creating the nodes in this specific order previously caused a bug
save = g.node("SaveImage")
@@ -427,7 +432,10 @@ class TestExecution:
result3 = client.run(g)
result4 = client.run(g)
assert result1.did_run(is_changed), "is_changed should have been run"
assert not result2.did_run(is_changed), "is_changed should have been cached"
if server["should_cache_results"]:
assert not result2.did_run(is_changed), "is_changed should have been cached"
else:
assert result2.did_run(is_changed), "is_changed should have been re-run"
assert result3.did_run(is_changed), "is_changed should have been re-run"
assert result4.did_run(is_changed), "is_changed should not have been cached"
@@ -514,7 +522,7 @@ class TestExecution:
assert len(images2) == 1, "Should have 1 image"
# This tests that only constant outputs are used in the call to `IS_CHANGED`
def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder):
def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder, server):
g = builder
input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)
@@ -530,7 +538,11 @@ class TestExecution:
images = result.get_images(output)
assert len(images) == 1, "Should have 1 image"
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
assert not result.did_run(test_node), "The execution should have been cached"
if server["should_cache_results"]:
assert not result.did_run(test_node), "The execution should have been cached"
else:
assert result.did_run(test_node), "The execution should have been re-run"
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
# Warmup execution to ensure server is fully initialized