From 38721fdb64c6aaeeb0349de6929e67ee51ff9e32 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 17 Jun 2025 20:35:32 -0500 Subject: [PATCH] Added hidden and state to passed-in clone of node class --- comfy_api/v3/io.py | 41 +++++++++++++++-- comfy_api/v3_01/io.py | 79 ++++++++++++++++++++++++++------ comfy_extras/nodes_v3_01_test.py | 16 +++++++ comfy_extras/nodes_v3_test.py | 2 +- execution.py | 42 +++++++++-------- 5 files changed, 143 insertions(+), 37 deletions(-) diff --git a/comfy_api/v3/io.py b/comfy_api/v3/io.py index f3ae110e6..9e146d7ae 100644 --- a/comfy_api/v3/io.py +++ b/comfy_api/v3/io.py @@ -549,7 +549,40 @@ class ComboDynamicInput(DynamicInput, io_type="COMFY_COMBODYNAMIC_V3"): AutoGrowDynamicInput(id="dynamic", template_input=ImageInput(id="image")) -class Hidden(str, Enum): +class Hidden: + def __init__(self, unique_id: str, prompt: Any, + extra_pnginfo: Any, dynprompt: Any, + auth_token_comfy_org: str, api_key_comfy_org: str, **kwargs): + self.unique_id = unique_id + """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages).""" + self.prompt = prompt + """PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description.""" + self.extra_pnginfo = extra_pnginfo + """EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node).""" + self.dynprompt = dynprompt + """DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion.""" + self.auth_token_comfy_org = auth_token_comfy_org + """AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend.""" + self.api_key_comfy_org = api_key_comfy_org + """API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend.""" + + def __getattr__(self, key: str): + '''If hidden variable not found, return None.''' + return None + + @classmethod + def from_dict(cls, d: dict): + return cls( + unique_id=d.get(HiddenEnum.unique_id), + prompt=d.get(HiddenEnum.prompt), + extra_pnginfo=d.get(HiddenEnum.extra_pnginfo), + dynprompt=d.get(HiddenEnum.dynprompt), + auth_token_comfy_org=d.get(HiddenEnum.auth_token_comfy_org), + api_key_comfy_org=d.get(HiddenEnum.api_key_comfy_org), + ) + + +class HiddenEnum(str, Enum): ''' Enumerator for requesting hidden variables in nodes. ''' @@ -607,7 +640,7 @@ class SchemaV3: """The category of the node, as per the "Add Node" menu.""" inputs: list[InputV3]=None outputs: list[OutputV3]=None - hidden: list[Hidden]=None + hidden: list[HiddenEnum]=None description: str="" """Node description, shown as a tooltip when hovering over the node.""" is_input_list: bool = False @@ -772,7 +805,7 @@ class ComfyNodeV3(ABC): raise Exception(f"No execute function was defined for node class {cls.__name__}.") @classmethod - def prepare_class_clone(cls) -> type[ComfyNodeV3]: + def prepare_class_clone(cls, hidden_inputs: dict, *args, **kwargs) -> type[ComfyNodeV3]: """Creates clone of real node class to prevent monkey-patching.""" c_type: type[ComfyNodeV3] = cls if is_class(cls) else type(cls) type_clone: type[ComfyNodeV3] = type(f"CLEAN_{c_type.__name__}", c_type.__bases__, {}) @@ -1063,7 +1096,7 @@ class TestNode(ComfyNodeV3): MaskInput("thing"), ], outputs=[ImageOutput("image_output")], - hidden=[Hidden.api_key_comfy_org, Hidden.auth_token_comfy_org, Hidden.unique_id] + hidden=[HiddenEnum.api_key_comfy_org, HiddenEnum.auth_token_comfy_org, HiddenEnum.unique_id] ) @classmethod diff --git a/comfy_api/v3_01/io.py b/comfy_api/v3_01/io.py index 07be5c14e..fa0e30489 100644 --- a/comfy_api/v3_01/io.py +++ b/comfy_api/v3_01/io.py @@ -213,6 +213,29 @@ class ComfyTypeIO(ComfyType): ... +class NodeState: + def __init__(self, node_id: str): + self.node_id = node_id + + +class NodeStateLocal(NodeState): + def __init__(self, node_id: str): + super().__init__(node_id) + self.local_state = {} + + def __getattr__(self, key: str): + local_state = type(self).__getattribute__(self, "local_state") + if key in local_state: + return local_state[key] + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'") + + def __setattr__(self, key: str, value: Any): + if key in ['node_id', 'local_state']: + super().__setattr__(key, value) + else: + self.local_state[key] = value + + @comfytype(io_type=IO.BOOLEAN) class Boolean: Type = bool @@ -528,7 +551,39 @@ class ComboDynamicInput(DynamicInput): AutoGrowDynamicInput(id="dynamic", template_input=Image.Input(id="image")) -class Hidden(str, Enum): +class Hidden: + def __init__(self, unique_id: str, prompt: Any, + extra_pnginfo: Any, dynprompt: Any, + auth_token_comfy_org: str, api_key_comfy_org: str, **kwargs): + self.unique_id = unique_id + """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages).""" + self.prompt = prompt + """PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description.""" + self.extra_pnginfo = extra_pnginfo + """EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node).""" + self.dynprompt = dynprompt + """DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion.""" + self.auth_token_comfy_org = auth_token_comfy_org + """AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend.""" + self.api_key_comfy_org = api_key_comfy_org + """API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend.""" + + def __getattr__(self, key: str): + '''If hidden variable not found, return None.''' + return None + + @classmethod + def from_dict(cls, d: dict): + return cls( + unique_id=d.get(HiddenEnum.unique_id), + prompt=d.get(HiddenEnum.prompt), + extra_pnginfo=d.get(HiddenEnum.extra_pnginfo), + dynprompt=d.get(HiddenEnum.dynprompt), + auth_token_comfy_org=d.get(HiddenEnum.auth_token_comfy_org), + api_key_comfy_org=d.get(HiddenEnum.api_key_comfy_org), + ) + +class HiddenEnum(str, Enum): ''' Enumerator for requesting hidden variables in nodes. ''' @@ -585,7 +640,7 @@ class SchemaV3: """The category of the node, as per the "Add Node" menu.""" inputs: list[InputV3]=None outputs: list[OutputV3]=None - hidden: list[Hidden]=None + hidden: list[HiddenEnum]=None description: str="" """Node description, shown as a tooltip when hovering over the node.""" is_input_list: bool = False @@ -691,16 +746,6 @@ class Serializer: pass -def prepare_class_clone(c: ComfyNodeV3 | type[ComfyNodeV3]) -> type[ComfyNodeV3]: - """Creates clone of real node class to prevent monkey-patching.""" - c_type: type[ComfyNodeV3] = c if is_class(c) else type(c) - type_clone: type[ComfyNodeV3] = type(f"CLEAN_{c_type.__name__}", c_type.__bases__, {}) - # TODO: what parameters should be carried over? - type_clone.SCHEMA = c_type.SCHEMA - # TODO: add anything we would want to expose inside node's execute function - return type_clone - - class classproperty(object): def __init__(self, f): self.f = f @@ -713,6 +758,10 @@ class ComfyNodeV3(BASE_CV3): RELATIVE_PYTHON_MODULE = None SCHEMA = None + + # filled in during execution + state: NodeState = None + hidden: Hidden = None @classmethod def GET_NODE_INFO_V3(cls) -> dict[str, Any]: @@ -740,6 +789,7 @@ class ComfyNodeV3(BASE_CV3): return [] def __init__(self): + self.local_state: NodeStateLocal = None self.__class__.VALIDATE_CLASS() @classmethod @@ -750,12 +800,13 @@ class ComfyNodeV3(BASE_CV3): raise Exception(f"No execute function was defined for node class {cls.__name__}.") @classmethod - def prepare_class_clone(cls) -> type[ComfyNodeV3]: + def prepare_class_clone(cls, hidden_inputs: dict, *args, **kwargs) -> type[ComfyNodeV3]: """Creates clone of real node class to prevent monkey-patching.""" c_type: type[ComfyNodeV3] = cls if is_class(cls) else type(cls) type_clone: type[ComfyNodeV3] = type(f"CLEAN_{c_type.__name__}", c_type.__bases__, {}) # TODO: what parameters should be carried over? type_clone.SCHEMA = c_type.SCHEMA + type_clone.hidden = Hidden.from_dict(hidden_inputs) # TODO: add anything we would want to expose inside node's execute function return type_clone @@ -1040,7 +1091,7 @@ class TestNode(ComfyNodeV3): Mask.Input("thing"), ], outputs=[Image.Output("image_output")], - hidden=[Hidden.api_key_comfy_org, Hidden.auth_token_comfy_org, Hidden.unique_id] + hidden=[HiddenEnum.api_key_comfy_org, HiddenEnum.auth_token_comfy_org, HiddenEnum.unique_id] ) @classmethod diff --git a/comfy_extras/nodes_v3_01_test.py b/comfy_extras/nodes_v3_01_test.py index 22fea4744..10821a72d 100644 --- a/comfy_extras/nodes_v3_01_test.py +++ b/comfy_extras/nodes_v3_01_test.py @@ -11,10 +11,17 @@ class XYZ: class Output(io.OutputV3): ... +class MyState(io.NodeState): + my_str: str + my_int: int + class V3TestNode(io.ComfyNodeV3): + state: MyState + def __init__(self): + super().__init__() self.hahajkunless = ";)" @classmethod @@ -64,6 +71,15 @@ class V3TestNode(io.ComfyNodeV3): @classmethod def execute(cls, image: io.Image.Type, some_int: int, combo: io.Combo.Type, combo2: io.MultiCombo.Type, xyz: XYZ.Type=None, mask: io.Mask.Type=None, **kwargs): + zzz = cls.hidden.prompt + cls.state.my_str = "LOLJK" + expected_int = 123 + cls.state.my_int = expected_int + if cls.state.my_int is None: + cls.state.my_int = expected_int + else: + if cls.state.my_int != expected_int: + raise Exception(f"Explicit state object did not maintain expected value: {cls.state.my_int} != {expected_int}") #some_int if hasattr(cls, "hahajkunless"): raise Exception("The 'cls' variable leaked instance state between runs!") diff --git a/comfy_extras/nodes_v3_test.py b/comfy_extras/nodes_v3_test.py index 641119376..b6edbca8c 100644 --- a/comfy_extras/nodes_v3_test.py +++ b/comfy_extras/nodes_v3_test.py @@ -3,7 +3,7 @@ from comfy_api.v3.io import ( ComfyNodeV3, SchemaV3, NumberDisplay, IntegerInput, MaskInput, ImageInput, ComboInput, CustomInput, StringInput, CustomType, IntegerOutput, ImageOutput, MultitypedInput, InputV3, OutputV3, - NodeOutput, Hidden + NodeOutput ) import logging diff --git a/execution.py b/execution.py index 6b885e476..d472729b0 100644 --- a/execution.py +++ b/execution.py @@ -28,7 +28,8 @@ from comfy_execution.graph import ( ) from comfy_execution.graph_utils import GraphBuilder, is_link from comfy_execution.validation import validate_node_input -from comfy_api.v3.io import NodeOutput, ComfyNodeV3, Hidden +from comfy_api.v3.io import NodeOutput, ComfyNodeV3, HiddenEnum +from comfy_api.v3_01.io import NodeStateLocal class ExecutionResult(Enum): @@ -61,7 +62,7 @@ class IsChangedCache: return self.is_changed[node_id] # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED - input_data_all, _, hidden_inputs_v3 = get_input_data(node["inputs"], class_def, node_id, None) + input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None) try: is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED") node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] @@ -148,13 +149,13 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e input_data_all[x] = [input_data] # V3 - if isinstance(class_def, ComfyNodeV3): - hidden_inputs_v3[Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {} - hidden_inputs_v3[Hidden.dynprompt] = dynprompt - hidden_inputs_v3[Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None) - hidden_inputs_v3[Hidden.unique_id] = unique_id - hidden_inputs_v3[Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None) - hidden_inputs_v3[Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None) + if isinstance(class_def, type(ComfyNodeV3)): + hidden_inputs_v3[HiddenEnum.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {} + hidden_inputs_v3[HiddenEnum.dynprompt] = dynprompt + hidden_inputs_v3[HiddenEnum.extra_pnginfo] = extra_data.get('extra_pnginfo', None) + hidden_inputs_v3[HiddenEnum.unique_id] = unique_id + hidden_inputs_v3[HiddenEnum.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None) + hidden_inputs_v3[HiddenEnum.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None) # V1 else: if "hidden" in valid_inputs: @@ -176,7 +177,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e map_node_over_list = None #Don't hook this please -def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): +def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None): # check if node wants the lists input_is_list = getattr(obj, "INPUT_IS_LIST", False) @@ -209,7 +210,12 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut # V3 if isinstance(obj, ComfyNodeV3): type(obj).VALIDATE_CLASS() - class_clone = type(obj).prepare_class_clone() + class_clone = type(obj).prepare_class_clone(hidden_inputs) + # NOTE: this is a mock of state management; for local, just stores NodeStateLocal on node instance + if hasattr(obj, "local_state"): + if obj.local_state is None: + obj.local_state = NodeStateLocal(class_clone.hidden.unique_id) + class_clone.state = obj.local_state results.append(getattr(type(obj), func).__func__(class_clone, **inputs)) # V1 else: @@ -248,11 +254,11 @@ def merge_result_data(results, obj): output.append([o[i] for o in results]) return output -def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None): +def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None): results = [] uis = [] subgraph_results = [] - return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) + return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) has_subgraph = False for i in range(len(return_values)): r = return_values[i] @@ -353,7 +359,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp output_ui = [] has_subgraph = False else: - input_data_all, missing_keys, hidden_inputs_v3 = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) + input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) @@ -364,7 +370,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp caches.objects.set(unique_id, obj) if hasattr(obj, "check_lazy_status"): - required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True) + required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs) required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) required_inputs = [x for x in required_inputs if isinstance(x,str) and ( x not in input_data_all or x in missing_keys @@ -394,7 +400,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp return block def pre_execute_cb(call_index): GraphBuilder.set_default_prefix(unique_id, call_index, 0) - output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) + output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) if len(output_ui) > 0: caches.ui.set(unique_id, { "meta": { @@ -795,7 +801,7 @@ def validate_inputs(prompt, item, validated): continue if len(validate_function_inputs) > 0 or validate_has_kwargs: - input_data_all, _, hidden_inputs_v3 = get_input_data(inputs, obj_class, unique_id) + input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id) input_filtered = {} for x in input_data_all: if x in validate_function_inputs or validate_has_kwargs: @@ -804,7 +810,7 @@ def validate_inputs(prompt, item, validated): input_filtered['input_types'] = [received_types] #ret = obj_class.VALIDATE_INPUTS(**input_filtered) - ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") + ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS", hidden_inputs=hidden_inputs) for x in input_filtered: for i, r in enumerate(ret): if r is not True and not isinstance(r, ExecutionBlocker):