From 0d185b721fbd6ccad57fc12d7f112a6698077448 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 1 Jun 2025 01:08:07 -0700 Subject: [PATCH] Created and handled NodeOutput class to be the return value of v3 nodes' execute function --- comfy_api/v3/io.py | 150 ++++++++++++++++++++++++++-------- comfy_extras/nodes_v3_test.py | 15 +++- execution.py | 17 ++++ 3 files changed, 145 insertions(+), 37 deletions(-) diff --git a/comfy_api/v3/io.py b/comfy_api/v3/io.py index c49ea434..3867f9c6 100644 --- a/comfy_api/v3/io.py +++ b/comfy_api/v3/io.py @@ -1,13 +1,14 @@ from __future__ import annotations -from typing import Union, Any +from typing import Any, Literal from enum import Enum from abc import ABC, abstractmethod from dataclasses import dataclass, asdict +from comfy.comfy_types.node_typing import IO + class InputBehavior(str, Enum): required = "required" optional = "optional" -# TODO: handle hidden inputs def is_class(obj): @@ -30,7 +31,7 @@ class IO_V3: def __init__(self): pass - def __init_subclass__(cls, io_type, **kwargs): + def __init_subclass__(cls, io_type: IO | str, **kwargs): cls.io_type = io_type super().__init_subclass__(**kwargs) @@ -75,11 +76,11 @@ class WidgetInputV3(InputV3, io_type=None): "widgetType": self.widgetType, }) -def CustomType(io_type: str) -> type[IO_V3]: +def CustomType(io_type: IO | str) -> type[IO_V3]: name = f"{io_type}_IO_V3" return type(name, (IO_V3,), {}, io_type=io_type) -def CustomInput(id: str, io_type: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None) -> InputV3: +def CustomInput(id: str, io_type: IO | str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None) -> InputV3: ''' Defines input for 'io_type'. Can be used to stand in for non-core types. ''' @@ -92,7 +93,7 @@ def CustomInput(id: str, io_type: str, display_name: str=None, behavior=InputBeh } return type(f"{io_type}Input", (InputV3,), {}, io_type=io_type)(**input_kwargs) -def CustomOutput(id: str, io_type: str, display_name: str=None, tooltip: str=None) -> OutputV3: +def CustomOutput(id: str, io_type: IO | str, display_name: str=None, tooltip: str=None) -> OutputV3: ''' Defines output for 'io_type'. Can be used to stand in for non-core types. ''' @@ -104,7 +105,7 @@ def CustomOutput(id: str, io_type: str, display_name: str=None, tooltip: str=Non return type(f"{io_type}Output", (OutputV3,), {}, io_type=io_type)(**input_kwargs) -class BooleanInput(WidgetInputV3, io_type="BOOLEAN"): +class BooleanInput(WidgetInputV3, io_type=IO.BOOLEAN): ''' Boolean input. ''' @@ -122,7 +123,7 @@ class BooleanInput(WidgetInputV3, io_type="BOOLEAN"): "label_off": self.label_off, }) -class IntegerInput(WidgetInputV3, io_type="INT"): +class IntegerInput(WidgetInputV3, io_type=IO.INT): ''' Integer input. ''' @@ -146,7 +147,7 @@ class IntegerInput(WidgetInputV3, io_type="INT"): "display": self.display_mode, # NOTE: in frontend, the parameter is called "display" }) -class FloatInput(WidgetInputV3, io_type="FLOAT"): +class FloatInput(WidgetInputV3, io_type=IO.FLOAT): ''' Float input. ''' @@ -171,7 +172,7 @@ class FloatInput(WidgetInputV3, io_type="FLOAT"): "display": self.display_mode, # NOTE: in frontend, the parameter is called "display" }) -class StringInput(WidgetInputV3, io_type="STRING"): +class StringInput(WidgetInputV3, io_type=IO.STRING): ''' String input. ''' @@ -189,7 +190,7 @@ class StringInput(WidgetInputV3, io_type="STRING"): "placeholder": self.placeholder, }) -class ComboInput(WidgetInputV3, io_type="COMBO"): +class ComboInput(WidgetInputV3, io_type=IO.COMBO): '''Combo input (dropdown).''' def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None, default: str=None, control_after_generate: bool=None, @@ -207,7 +208,7 @@ class ComboInput(WidgetInputV3, io_type="COMBO"): "control_after_generate": self.control_after_generate, }) -class MultiselectComboWidget(ComboInput, io_type="COMBO"): +class MultiselectComboWidget(ComboInput, io_type=IO.COMBO): '''Multiselect Combo input (dropdown for selecting potentially more than one value).''' def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None, default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None, @@ -225,21 +226,21 @@ class MultiselectComboWidget(ComboInput, io_type="COMBO"): "chip": self.chip, }) -class ImageInput(InputV3, io_type="IMAGE"): +class ImageInput(InputV3, io_type=IO.IMAGE): ''' Image input. ''' def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None): super().__init__(id, display_name, behavior, tooltip) -class MaskInput(InputV3, io_type="MASK"): +class MaskInput(InputV3, io_type=IO.MASK): ''' Mask input. ''' def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None): super().__init__(id, display_name, behavior, tooltip) -class LatentInput(InputV3, io_type="LATENT"): +class LatentInput(InputV3, io_type=IO.LATENT): ''' Latent input. ''' @@ -250,7 +251,7 @@ class MultitypedInput(InputV3, io_type="COMFY_MULTITYPED_V3"): ''' Input that permits more than one input type. ''' - def __init__(self, id: str, io_types: list[Union[type[IO_V3], InputV3, str]], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None,): + def __init__(self, id: str, io_types: list[type[IO_V3] | InputV3 | IO |str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None,): super().__init__(id, display_name, behavior, tooltip) self._io_types = io_types @@ -283,24 +284,24 @@ class OutputV3: cls.io_type = io_type super().__init_subclass__(**kwargs) -class IntegerOutput(OutputV3, io_type="INT"): +class IntegerOutput(OutputV3, io_type=IO.INT): pass -class FloatOutput(OutputV3, io_type="FLOAT"): +class FloatOutput(OutputV3, io_type=IO.FLOAT): pass -class StringOutput(OutputV3, io_type="STRING"): +class StringOutput(OutputV3, io_type=IO.STRING): pass # def __init__(self, id: str, display_name: str=None, tooltip: str=None): # super().__init__(id, display_name, tooltip) -class ImageOutput(OutputV3, io_type="IMAGE"): +class ImageOutput(OutputV3, io_type=IO.IMAGE): pass -class MaskOutput(OutputV3, io_type="MASK"): +class MaskOutput(OutputV3, io_type=IO.MASK): pass -class LatentOutput(OutputV3, io_type="LATENT"): +class LatentOutput(OutputV3, io_type=IO.LATENT): pass @@ -675,6 +676,12 @@ class ComfyNodeV3(ABC): #-------------------------------------------- ############################################# + @classmethod + def GET_NODE_INFO_V3(cls) -> dict[str, Any]: + schema = cls.GET_SCHEMA() + # TODO: finish + return None + @classmethod @abstractmethod @@ -690,26 +697,103 @@ class ComfyNodeV3(ABC): raise Exception("No DEFINE_SCHEMA function was defined for this node.") @abstractmethod - def execute(self, inputs, outputs, hidden, **kwargs): + def execute(self, **kwargs) -> NodeOutput: pass -class ReturnedInputs: +# class ReturnedInputs: +# def __init__(self): +# pass + +# class ReturnedOutputs: +# def __init__(self): +# pass + + +class NodeOutput: + ''' + Standardized output of a node; can pass in any number of args and/or a UIOutput into 'ui' kwarg. + ''' + def __init__(self, *args: Any, ui: UIOutput | dict=None, expand: dict=None, block_execution: str=None, **kwargs): + self.args = args + self.ui = ui + self.expand = expand + self.block_execution = block_execution + + @property + def result(self): + return self.args if len(self.args) > 0 else None + + +class SavedResult: + def __init__(self, filename: str, subfolder: str, type: Literal["input", "output", "temp"]): + self.filename = filename + self.subfolder = subfolder + self.type = type + + def as_dict(self): + return { + "filename": self.filename, + "subfolder": self.subfolder, + "type": self.type + } + +class UIOutput(ABC): def __init__(self): pass -class ReturnedOutputs: - def __init__(self): - pass + @abstractmethod + def as_dict(self) -> dict: + ... # TODO: finish +class UIImages(UIOutput): + def __init__(self, values: list[SavedResult | dict], animated=False, **kwargs): + self.values = values + self.animated = animated + + def as_dict(self): + values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values] + return { + "images": values, + "animated": (self.animated,) + } -class NodeOutputV3: - def __init__(self): - pass +class UILatents(UIOutput): + def __init__(self, values: list[SavedResult | dict], **kwargs): + self.values = values + + def as_dict(self): + values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values] + return { + "latents": values, + } -class UINodeOutput: - def __init__(self): - pass +class UIAudio(UIOutput): + def __init__(self, values: list[SavedResult | dict], **kwargs): + self.values = values + + def as_dict(self): + values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values] + return { + "audio": values, + } + +class UI3D(UIOutput): + def __init__(self, values: list[SavedResult | dict], **kwargs): + self.values = values + + def as_dict(self): + values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values] + return { + "3d": values, + } + +class UIText(UIOutput): + def __init__(self, value: str, **kwargs): + self.value = value + + def as_dict(self): + return {"text": (self.value,)} class TestNode(ComfyNodeV3): diff --git a/comfy_extras/nodes_v3_test.py b/comfy_extras/nodes_v3_test.py index 76bb0c67..fdcae952 100644 --- a/comfy_extras/nodes_v3_test.py +++ b/comfy_extras/nodes_v3_test.py @@ -2,7 +2,7 @@ import torch from comfy_api.v3.io import ( ComfyNodeV3, SchemaV3, CustomType, CustomInput, CustomOutput, InputBehavior, NumberDisplay, - IntegerInput, MaskInput, ImageInput, ComboDynamicInput, + IntegerInput, MaskInput, ImageInput, ComboDynamicInput, NodeOutput, ) @@ -14,7 +14,7 @@ class V3TestNode(ComfyNodeV3): @classmethod def DEFINE_SCHEMA(cls): - schema = SchemaV3( + return SchemaV3( node_id="V3TestNode1", display_name="V3 Test Node (1djekjd)", description="This is a funky V3 node test.", @@ -36,10 +36,17 @@ class V3TestNode(ComfyNodeV3): ], is_output_node=True, ) - return schema def execute(self, some_int: int, image: torch.Tensor, mask: torch.Tensor=None, **kwargs): - return (None,) + a = NodeOutput(1) + aa = NodeOutput(1, "hellothere") + ab = NodeOutput(1, "hellothere", ui={"lol": "jk"}) + b = NodeOutput() + c = NodeOutput(ui={"lol": "jk"}) + return NodeOutput() + return NodeOutput(1) + return NodeOutput(1, block_execution="Kill yourself") + return () diff --git a/execution.py b/execution.py index 15ff7567..64707c28 100644 --- a/execution.py +++ b/execution.py @@ -17,6 +17,7 @@ from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, from comfy_execution.graph_utils import is_link, GraphBuilder from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID from comfy_execution.validation import validate_node_input +from comfy_api.v3.io import NodeOutput class ExecutionResult(Enum): SUCCESS = 0 @@ -242,6 +243,22 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb result = tuple([result] * len(obj.RETURN_TYPES)) results.append(result) subgraph_results.append((None, result)) + elif isinstance(r, NodeOutput): + if r.ui is not None: + uis.append(r.ui.as_dict()) + if r.expand is not None: + has_subgraph = True + new_graph = r.expand + result = r.result + if r.block_execution is not None: + result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES)) + subgraph_results.append((new_graph, result)) + elif r.result is not None: + result = r.result + if r.block_execution is not None: + result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES)) + results.append(result) + subgraph_results.append((None, result)) else: if isinstance(r, ExecutionBlocker): r = tuple([r] * len(obj.RETURN_TYPES))