Progress on state management mocking and hidden values in v3

This commit is contained in:
kosinkadink1@gmail.com 2025-06-16 19:10:51 -07:00
parent 54e0d6b161
commit ef04c46ee3

View File

@ -17,7 +17,7 @@ from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt,
from comfy_execution.graph_utils import is_link, GraphBuilder from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.validation import validate_node_input from comfy_execution.validation import validate_node_input
from comfy_api.v3.io import NodeOutput, ComfyNodeV3 from comfy_api.v3.io import NodeOutput, ComfyNodeV3, Hidden
class ExecutionResult(Enum): class ExecutionResult(Enum):
SUCCESS = 0 SUCCESS = 0
@ -49,7 +49,7 @@ class IsChangedCache:
return self.is_changed[node_id] return self.is_changed[node_id]
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None) input_data_all, _, hidden_inputs_v3 = get_input_data(node["inputs"], class_def, node_id, None)
try: try:
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED") is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
@ -110,6 +110,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
valid_inputs = class_def.INPUT_TYPES() valid_inputs = class_def.INPUT_TYPES()
input_data_all = {} input_data_all = {}
missing_keys = {} missing_keys = {}
hidden_inputs_v3 = {}
for x in inputs: for x in inputs:
input_data = inputs[x] input_data = inputs[x]
_, input_category, input_info = get_input_info(class_def, x, valid_inputs) _, input_category, input_info = get_input_info(class_def, x, valid_inputs)
@ -134,22 +135,32 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
elif input_category is not None: elif input_category is not None:
input_data_all[x] = [input_data] input_data_all[x] = [input_data]
if "hidden" in valid_inputs: # V3
h = valid_inputs["hidden"] if isinstance(class_def, ComfyNodeV3):
for x in h: hidden_inputs_v3[Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {}
if h[x] == "PROMPT": hidden_inputs_v3[Hidden.dynprompt] = dynprompt
input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}] hidden_inputs_v3[Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None)
if h[x] == "DYNPROMPT": hidden_inputs_v3[Hidden.unique_id] = unique_id
input_data_all[x] = [dynprompt] hidden_inputs_v3[Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
if h[x] == "EXTRA_PNGINFO": hidden_inputs_v3[Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
input_data_all[x] = [extra_data.get('extra_pnginfo', None)] # V1
if h[x] == "UNIQUE_ID": else:
input_data_all[x] = [unique_id] if "hidden" in valid_inputs:
if h[x] == "AUTH_TOKEN_COMFY_ORG": h = valid_inputs["hidden"]
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] for x in h:
if h[x] == "API_KEY_COMFY_ORG": if h[x] == "PROMPT":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
return input_data_all, missing_keys if h[x] == "DYNPROMPT":
input_data_all[x] = [dynprompt]
if h[x] == "EXTRA_PNGINFO":
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
if h[x] == "UNIQUE_ID":
input_data_all[x] = [unique_id]
if h[x] == "AUTH_TOKEN_COMFY_ORG":
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
return input_data_all, missing_keys, hidden_inputs_v3
map_node_over_list = None #Don't hook this please map_node_over_list = None #Don't hook this please
@ -187,7 +198,7 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
if isinstance(obj, ComfyNodeV3): if isinstance(obj, ComfyNodeV3):
type(obj).VALIDATE_CLASS() type(obj).VALIDATE_CLASS()
class_clone = type(obj).prepare_class_clone() class_clone = type(obj).prepare_class_clone()
results.append(type(obj).execute.__func__(class_clone, **inputs)) results.append(getattr(type(obj), func).__func__(class_clone, **inputs))
# V1 # V1
else: else:
results.append(getattr(obj, func)(**inputs)) results.append(getattr(obj, func)(**inputs))
@ -251,6 +262,7 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
results.append(result) results.append(result)
subgraph_results.append((None, result)) subgraph_results.append((None, result))
elif isinstance(r, NodeOutput): elif isinstance(r, NodeOutput):
# V3
if r.ui is not None: if r.ui is not None:
uis.append(r.ui.as_dict()) uis.append(r.ui.as_dict())
if r.expand is not None: if r.expand is not None:
@ -329,7 +341,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
output_ui = [] output_ui = []
has_subgraph = False has_subgraph = False
else: else:
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) input_data_all, missing_keys, hidden_inputs_v3 = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
if server.client_id is not None: if server.client_id is not None:
server.last_node_id = display_node_id server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@ -771,7 +783,7 @@ def validate_inputs(prompt, item, validated):
continue continue
if len(validate_function_inputs) > 0 or validate_has_kwargs: if len(validate_function_inputs) > 0 or validate_has_kwargs:
input_data_all, _ = get_input_data(inputs, obj_class, unique_id) input_data_all, _, hidden_inputs_v3 = get_input_data(inputs, obj_class, unique_id)
input_filtered = {} input_filtered = {}
for x in input_data_all: for x in input_data_all:
if x in validate_function_inputs or validate_has_kwargs: if x in validate_function_inputs or validate_has_kwargs: