Progress on state management mocking and hidden values in v3

This commit is contained in:
kosinkadink1@gmail.com 2025-06-16 19:10:51 -07:00
parent 54e0d6b161
commit ef04c46ee3

View File

@ -17,7 +17,7 @@ from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt,
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.validation import validate_node_input
from comfy_api.v3.io import NodeOutput, ComfyNodeV3
from comfy_api.v3.io import NodeOutput, ComfyNodeV3, Hidden
class ExecutionResult(Enum):
SUCCESS = 0
@ -49,7 +49,7 @@ class IsChangedCache:
return self.is_changed[node_id]
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
input_data_all, _, hidden_inputs_v3 = get_input_data(node["inputs"], class_def, node_id, None)
try:
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
@ -110,6 +110,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
missing_keys = {}
hidden_inputs_v3 = {}
for x in inputs:
input_data = inputs[x]
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
@ -134,6 +135,16 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
elif input_category is not None:
input_data_all[x] = [input_data]
# V3
if isinstance(class_def, ComfyNodeV3):
hidden_inputs_v3[Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {}
hidden_inputs_v3[Hidden.dynprompt] = dynprompt
hidden_inputs_v3[Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None)
hidden_inputs_v3[Hidden.unique_id] = unique_id
hidden_inputs_v3[Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
hidden_inputs_v3[Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
# V1
else:
if "hidden" in valid_inputs:
h = valid_inputs["hidden"]
for x in h:
@ -149,7 +160,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
return input_data_all, missing_keys
return input_data_all, missing_keys, hidden_inputs_v3
map_node_over_list = None #Don't hook this please
@ -187,7 +198,7 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
if isinstance(obj, ComfyNodeV3):
type(obj).VALIDATE_CLASS()
class_clone = type(obj).prepare_class_clone()
results.append(type(obj).execute.__func__(class_clone, **inputs))
results.append(getattr(type(obj), func).__func__(class_clone, **inputs))
# V1
else:
results.append(getattr(obj, func)(**inputs))
@ -251,6 +262,7 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
results.append(result)
subgraph_results.append((None, result))
elif isinstance(r, NodeOutput):
# V3
if r.ui is not None:
uis.append(r.ui.as_dict())
if r.expand is not None:
@ -329,7 +341,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
output_ui = []
has_subgraph = False
else:
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
input_data_all, missing_keys, hidden_inputs_v3 = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@ -771,7 +783,7 @@ def validate_inputs(prompt, item, validated):
continue
if len(validate_function_inputs) > 0 or validate_has_kwargs:
input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
input_data_all, _, hidden_inputs_v3 = get_input_data(inputs, obj_class, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs or validate_has_kwargs: