mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-13 04:55:53 +00:00
Added hidden and state to passed-in clone of node class
This commit is contained in:
42
execution.py
42
execution.py
@@ -28,7 +28,8 @@ from comfy_execution.graph import (
|
||||
)
|
||||
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||
from comfy_execution.validation import validate_node_input
|
||||
from comfy_api.v3.io import NodeOutput, ComfyNodeV3, Hidden
|
||||
from comfy_api.v3.io import NodeOutput, ComfyNodeV3, HiddenEnum
|
||||
from comfy_api.v3_01.io import NodeStateLocal
|
||||
|
||||
|
||||
class ExecutionResult(Enum):
|
||||
@@ -61,7 +62,7 @@ class IsChangedCache:
|
||||
return self.is_changed[node_id]
|
||||
|
||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||
input_data_all, _, hidden_inputs_v3 = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
try:
|
||||
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||
@@ -148,13 +149,13 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
|
||||
input_data_all[x] = [input_data]
|
||||
|
||||
# V3
|
||||
if isinstance(class_def, ComfyNodeV3):
|
||||
hidden_inputs_v3[Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {}
|
||||
hidden_inputs_v3[Hidden.dynprompt] = dynprompt
|
||||
hidden_inputs_v3[Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None)
|
||||
hidden_inputs_v3[Hidden.unique_id] = unique_id
|
||||
hidden_inputs_v3[Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
|
||||
hidden_inputs_v3[Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
|
||||
if isinstance(class_def, type(ComfyNodeV3)):
|
||||
hidden_inputs_v3[HiddenEnum.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {}
|
||||
hidden_inputs_v3[HiddenEnum.dynprompt] = dynprompt
|
||||
hidden_inputs_v3[HiddenEnum.extra_pnginfo] = extra_data.get('extra_pnginfo', None)
|
||||
hidden_inputs_v3[HiddenEnum.unique_id] = unique_id
|
||||
hidden_inputs_v3[HiddenEnum.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
|
||||
hidden_inputs_v3[HiddenEnum.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
|
||||
# V1
|
||||
else:
|
||||
if "hidden" in valid_inputs:
|
||||
@@ -176,7 +177,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
|
||||
|
||||
map_node_over_list = None #Don't hook this please
|
||||
|
||||
def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
|
||||
# check if node wants the lists
|
||||
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
||||
|
||||
@@ -209,7 +210,12 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
|
||||
# V3
|
||||
if isinstance(obj, ComfyNodeV3):
|
||||
type(obj).VALIDATE_CLASS()
|
||||
class_clone = type(obj).prepare_class_clone()
|
||||
class_clone = type(obj).prepare_class_clone(hidden_inputs)
|
||||
# NOTE: this is a mock of state management; for local, just stores NodeStateLocal on node instance
|
||||
if hasattr(obj, "local_state"):
|
||||
if obj.local_state is None:
|
||||
obj.local_state = NodeStateLocal(class_clone.hidden.unique_id)
|
||||
class_clone.state = obj.local_state
|
||||
results.append(getattr(type(obj), func).__func__(class_clone, **inputs))
|
||||
# V1
|
||||
else:
|
||||
@@ -248,11 +254,11 @@ def merge_result_data(results, obj):
|
||||
output.append([o[i] for o in results])
|
||||
return output
|
||||
|
||||
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
|
||||
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
|
||||
results = []
|
||||
uis = []
|
||||
subgraph_results = []
|
||||
return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
|
||||
has_subgraph = False
|
||||
for i in range(len(return_values)):
|
||||
r = return_values[i]
|
||||
@@ -353,7 +359,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
output_ui = []
|
||||
has_subgraph = False
|
||||
else:
|
||||
input_data_all, missing_keys, hidden_inputs_v3 = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = display_node_id
|
||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||
@@ -364,7 +370,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
caches.objects.set(unique_id, obj)
|
||||
|
||||
if hasattr(obj, "check_lazy_status"):
|
||||
required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
||||
required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
|
||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
||||
x not in input_data_all or x in missing_keys
|
||||
@@ -394,7 +400,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
return block
|
||||
def pre_execute_cb(call_index):
|
||||
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
|
||||
if len(output_ui) > 0:
|
||||
caches.ui.set(unique_id, {
|
||||
"meta": {
|
||||
@@ -795,7 +801,7 @@ def validate_inputs(prompt, item, validated):
|
||||
continue
|
||||
|
||||
if len(validate_function_inputs) > 0 or validate_has_kwargs:
|
||||
input_data_all, _, hidden_inputs_v3 = get_input_data(inputs, obj_class, unique_id)
|
||||
input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id)
|
||||
input_filtered = {}
|
||||
for x in input_data_all:
|
||||
if x in validate_function_inputs or validate_has_kwargs:
|
||||
@@ -804,7 +810,7 @@ def validate_inputs(prompt, item, validated):
|
||||
input_filtered['input_types'] = [received_types]
|
||||
|
||||
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
|
||||
ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
|
||||
ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS", hidden_inputs=hidden_inputs)
|
||||
for x in input_filtered:
|
||||
for i, r in enumerate(ret):
|
||||
if r is not True and not isinstance(r, ExecutionBlocker):
|
||||
|
Reference in New Issue
Block a user