Created and handled NodeOutput class to be the return value of v3 nodes' execute function

This commit is contained in:
Jedrzej Kosinski 2025-06-01 01:08:07 -07:00
parent 8642757971
commit 0d185b721f
3 changed files with 145 additions and 37 deletions

View File

@ -1,13 +1,14 @@
from __future__ import annotations from __future__ import annotations
from typing import Union, Any from typing import Any, Literal
from enum import Enum from enum import Enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict from dataclasses import dataclass, asdict
from comfy.comfy_types.node_typing import IO
class InputBehavior(str, Enum): class InputBehavior(str, Enum):
required = "required" required = "required"
optional = "optional" optional = "optional"
# TODO: handle hidden inputs
def is_class(obj): def is_class(obj):
@ -30,7 +31,7 @@ class IO_V3:
def __init__(self): def __init__(self):
pass pass
def __init_subclass__(cls, io_type, **kwargs): def __init_subclass__(cls, io_type: IO | str, **kwargs):
cls.io_type = io_type cls.io_type = io_type
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
@ -75,11 +76,11 @@ class WidgetInputV3(InputV3, io_type=None):
"widgetType": self.widgetType, "widgetType": self.widgetType,
}) })
def CustomType(io_type: str) -> type[IO_V3]: def CustomType(io_type: IO | str) -> type[IO_V3]:
name = f"{io_type}_IO_V3" name = f"{io_type}_IO_V3"
return type(name, (IO_V3,), {}, io_type=io_type) return type(name, (IO_V3,), {}, io_type=io_type)
def CustomInput(id: str, io_type: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None) -> InputV3: def CustomInput(id: str, io_type: IO | str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None) -> InputV3:
''' '''
Defines input for 'io_type'. Can be used to stand in for non-core types. Defines input for 'io_type'. Can be used to stand in for non-core types.
''' '''
@ -92,7 +93,7 @@ def CustomInput(id: str, io_type: str, display_name: str=None, behavior=InputBeh
} }
return type(f"{io_type}Input", (InputV3,), {}, io_type=io_type)(**input_kwargs) return type(f"{io_type}Input", (InputV3,), {}, io_type=io_type)(**input_kwargs)
def CustomOutput(id: str, io_type: str, display_name: str=None, tooltip: str=None) -> OutputV3: def CustomOutput(id: str, io_type: IO | str, display_name: str=None, tooltip: str=None) -> OutputV3:
''' '''
Defines output for 'io_type'. Can be used to stand in for non-core types. Defines output for 'io_type'. Can be used to stand in for non-core types.
''' '''
@ -104,7 +105,7 @@ def CustomOutput(id: str, io_type: str, display_name: str=None, tooltip: str=Non
return type(f"{io_type}Output", (OutputV3,), {}, io_type=io_type)(**input_kwargs) return type(f"{io_type}Output", (OutputV3,), {}, io_type=io_type)(**input_kwargs)
class BooleanInput(WidgetInputV3, io_type="BOOLEAN"): class BooleanInput(WidgetInputV3, io_type=IO.BOOLEAN):
''' '''
Boolean input. Boolean input.
''' '''
@ -122,7 +123,7 @@ class BooleanInput(WidgetInputV3, io_type="BOOLEAN"):
"label_off": self.label_off, "label_off": self.label_off,
}) })
class IntegerInput(WidgetInputV3, io_type="INT"): class IntegerInput(WidgetInputV3, io_type=IO.INT):
''' '''
Integer input. Integer input.
''' '''
@ -146,7 +147,7 @@ class IntegerInput(WidgetInputV3, io_type="INT"):
"display": self.display_mode, # NOTE: in frontend, the parameter is called "display" "display": self.display_mode, # NOTE: in frontend, the parameter is called "display"
}) })
class FloatInput(WidgetInputV3, io_type="FLOAT"): class FloatInput(WidgetInputV3, io_type=IO.FLOAT):
''' '''
Float input. Float input.
''' '''
@ -171,7 +172,7 @@ class FloatInput(WidgetInputV3, io_type="FLOAT"):
"display": self.display_mode, # NOTE: in frontend, the parameter is called "display" "display": self.display_mode, # NOTE: in frontend, the parameter is called "display"
}) })
class StringInput(WidgetInputV3, io_type="STRING"): class StringInput(WidgetInputV3, io_type=IO.STRING):
''' '''
String input. String input.
''' '''
@ -189,7 +190,7 @@ class StringInput(WidgetInputV3, io_type="STRING"):
"placeholder": self.placeholder, "placeholder": self.placeholder,
}) })
class ComboInput(WidgetInputV3, io_type="COMBO"): class ComboInput(WidgetInputV3, io_type=IO.COMBO):
'''Combo input (dropdown).''' '''Combo input (dropdown).'''
def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: str=None, control_after_generate: bool=None, default: str=None, control_after_generate: bool=None,
@ -207,7 +208,7 @@ class ComboInput(WidgetInputV3, io_type="COMBO"):
"control_after_generate": self.control_after_generate, "control_after_generate": self.control_after_generate,
}) })
class MultiselectComboWidget(ComboInput, io_type="COMBO"): class MultiselectComboWidget(ComboInput, io_type=IO.COMBO):
'''Multiselect Combo input (dropdown for selecting potentially more than one value).''' '''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None, default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
@ -225,21 +226,21 @@ class MultiselectComboWidget(ComboInput, io_type="COMBO"):
"chip": self.chip, "chip": self.chip,
}) })
class ImageInput(InputV3, io_type="IMAGE"): class ImageInput(InputV3, io_type=IO.IMAGE):
''' '''
Image input. Image input.
''' '''
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None): def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
super().__init__(id, display_name, behavior, tooltip) super().__init__(id, display_name, behavior, tooltip)
class MaskInput(InputV3, io_type="MASK"): class MaskInput(InputV3, io_type=IO.MASK):
''' '''
Mask input. Mask input.
''' '''
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None): def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
super().__init__(id, display_name, behavior, tooltip) super().__init__(id, display_name, behavior, tooltip)
class LatentInput(InputV3, io_type="LATENT"): class LatentInput(InputV3, io_type=IO.LATENT):
''' '''
Latent input. Latent input.
''' '''
@ -250,7 +251,7 @@ class MultitypedInput(InputV3, io_type="COMFY_MULTITYPED_V3"):
''' '''
Input that permits more than one input type. Input that permits more than one input type.
''' '''
def __init__(self, id: str, io_types: list[Union[type[IO_V3], InputV3, str]], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None,): def __init__(self, id: str, io_types: list[type[IO_V3] | InputV3 | IO |str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None,):
super().__init__(id, display_name, behavior, tooltip) super().__init__(id, display_name, behavior, tooltip)
self._io_types = io_types self._io_types = io_types
@ -283,24 +284,24 @@ class OutputV3:
cls.io_type = io_type cls.io_type = io_type
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
class IntegerOutput(OutputV3, io_type="INT"): class IntegerOutput(OutputV3, io_type=IO.INT):
pass pass
class FloatOutput(OutputV3, io_type="FLOAT"): class FloatOutput(OutputV3, io_type=IO.FLOAT):
pass pass
class StringOutput(OutputV3, io_type="STRING"): class StringOutput(OutputV3, io_type=IO.STRING):
pass pass
# def __init__(self, id: str, display_name: str=None, tooltip: str=None): # def __init__(self, id: str, display_name: str=None, tooltip: str=None):
# super().__init__(id, display_name, tooltip) # super().__init__(id, display_name, tooltip)
class ImageOutput(OutputV3, io_type="IMAGE"): class ImageOutput(OutputV3, io_type=IO.IMAGE):
pass pass
class MaskOutput(OutputV3, io_type="MASK"): class MaskOutput(OutputV3, io_type=IO.MASK):
pass pass
class LatentOutput(OutputV3, io_type="LATENT"): class LatentOutput(OutputV3, io_type=IO.LATENT):
pass pass
@ -675,6 +676,12 @@ class ComfyNodeV3(ABC):
#-------------------------------------------- #--------------------------------------------
############################################# #############################################
@classmethod
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
schema = cls.GET_SCHEMA()
# TODO: finish
return None
@classmethod @classmethod
@abstractmethod @abstractmethod
@ -690,26 +697,103 @@ class ComfyNodeV3(ABC):
raise Exception("No DEFINE_SCHEMA function was defined for this node.") raise Exception("No DEFINE_SCHEMA function was defined for this node.")
@abstractmethod @abstractmethod
def execute(self, inputs, outputs, hidden, **kwargs): def execute(self, **kwargs) -> NodeOutput:
pass pass
class ReturnedInputs: # class ReturnedInputs:
# def __init__(self):
# pass
# class ReturnedOutputs:
# def __init__(self):
# pass
class NodeOutput:
'''
Standardized output of a node; can pass in any number of args and/or a UIOutput into 'ui' kwarg.
'''
def __init__(self, *args: Any, ui: UIOutput | dict=None, expand: dict=None, block_execution: str=None, **kwargs):
self.args = args
self.ui = ui
self.expand = expand
self.block_execution = block_execution
@property
def result(self):
return self.args if len(self.args) > 0 else None
class SavedResult:
def __init__(self, filename: str, subfolder: str, type: Literal["input", "output", "temp"]):
self.filename = filename
self.subfolder = subfolder
self.type = type
def as_dict(self):
return {
"filename": self.filename,
"subfolder": self.subfolder,
"type": self.type
}
class UIOutput(ABC):
def __init__(self): def __init__(self):
pass pass
class ReturnedOutputs: @abstractmethod
def __init__(self): def as_dict(self) -> dict:
pass ... # TODO: finish
class UIImages(UIOutput):
def __init__(self, values: list[SavedResult | dict], animated=False, **kwargs):
self.values = values
self.animated = animated
class NodeOutputV3: def as_dict(self):
def __init__(self): values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
pass return {
"images": values,
"animated": (self.animated,)
}
class UINodeOutput: class UILatents(UIOutput):
def __init__(self): def __init__(self, values: list[SavedResult | dict], **kwargs):
pass self.values = values
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"latents": values,
}
class UIAudio(UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"audio": values,
}
class UI3D(UIOutput):
def __init__(self, values: list[SavedResult | dict], **kwargs):
self.values = values
def as_dict(self):
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
return {
"3d": values,
}
class UIText(UIOutput):
def __init__(self, value: str, **kwargs):
self.value = value
def as_dict(self):
return {"text": (self.value,)}
class TestNode(ComfyNodeV3): class TestNode(ComfyNodeV3):

View File

@ -2,7 +2,7 @@ import torch
from comfy_api.v3.io import ( from comfy_api.v3.io import (
ComfyNodeV3, SchemaV3, CustomType, CustomInput, CustomOutput, InputBehavior, NumberDisplay, ComfyNodeV3, SchemaV3, CustomType, CustomInput, CustomOutput, InputBehavior, NumberDisplay,
IntegerInput, MaskInput, ImageInput, ComboDynamicInput, IntegerInput, MaskInput, ImageInput, ComboDynamicInput, NodeOutput,
) )
@ -14,7 +14,7 @@ class V3TestNode(ComfyNodeV3):
@classmethod @classmethod
def DEFINE_SCHEMA(cls): def DEFINE_SCHEMA(cls):
schema = SchemaV3( return SchemaV3(
node_id="V3TestNode1", node_id="V3TestNode1",
display_name="V3 Test Node (1djekjd)", display_name="V3 Test Node (1djekjd)",
description="This is a funky V3 node test.", description="This is a funky V3 node test.",
@ -36,10 +36,17 @@ class V3TestNode(ComfyNodeV3):
], ],
is_output_node=True, is_output_node=True,
) )
return schema
def execute(self, some_int: int, image: torch.Tensor, mask: torch.Tensor=None, **kwargs): def execute(self, some_int: int, image: torch.Tensor, mask: torch.Tensor=None, **kwargs):
return (None,) a = NodeOutput(1)
aa = NodeOutput(1, "hellothere")
ab = NodeOutput(1, "hellothere", ui={"lol": "jk"})
b = NodeOutput()
c = NodeOutput(ui={"lol": "jk"})
return NodeOutput()
return NodeOutput(1)
return NodeOutput(1, block_execution="Kill yourself")
return ()

View File

@ -17,6 +17,7 @@ from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt,
from comfy_execution.graph_utils import is_link, GraphBuilder from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.validation import validate_node_input from comfy_execution.validation import validate_node_input
from comfy_api.v3.io import NodeOutput
class ExecutionResult(Enum): class ExecutionResult(Enum):
SUCCESS = 0 SUCCESS = 0
@ -242,6 +243,22 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
result = tuple([result] * len(obj.RETURN_TYPES)) result = tuple([result] * len(obj.RETURN_TYPES))
results.append(result) results.append(result)
subgraph_results.append((None, result)) subgraph_results.append((None, result))
elif isinstance(r, NodeOutput):
if r.ui is not None:
uis.append(r.ui.as_dict())
if r.expand is not None:
has_subgraph = True
new_graph = r.expand
result = r.result
if r.block_execution is not None:
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
subgraph_results.append((new_graph, result))
elif r.result is not None:
result = r.result
if r.block_execution is not None:
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
results.append(result)
subgraph_results.append((None, result))
else: else:
if isinstance(r, ExecutionBlocker): if isinstance(r, ExecutionBlocker):
r = tuple([r] * len(obj.RETURN_TYPES)) r = tuple([r] * len(obj.RETURN_TYPES))