mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-09 23:57:14 +00:00
Created and handled NodeOutput class to be the return value of v3 nodes' execute function
This commit is contained in:
parent
8642757971
commit
0d185b721f
@ -1,13 +1,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import Union, Any
|
from typing import Any, Literal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, asdict
|
from dataclasses import dataclass, asdict
|
||||||
|
from comfy.comfy_types.node_typing import IO
|
||||||
|
|
||||||
|
|
||||||
class InputBehavior(str, Enum):
|
class InputBehavior(str, Enum):
|
||||||
required = "required"
|
required = "required"
|
||||||
optional = "optional"
|
optional = "optional"
|
||||||
# TODO: handle hidden inputs
|
|
||||||
|
|
||||||
|
|
||||||
def is_class(obj):
|
def is_class(obj):
|
||||||
@ -30,7 +31,7 @@ class IO_V3:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __init_subclass__(cls, io_type, **kwargs):
|
def __init_subclass__(cls, io_type: IO | str, **kwargs):
|
||||||
cls.io_type = io_type
|
cls.io_type = io_type
|
||||||
super().__init_subclass__(**kwargs)
|
super().__init_subclass__(**kwargs)
|
||||||
|
|
||||||
@ -75,11 +76,11 @@ class WidgetInputV3(InputV3, io_type=None):
|
|||||||
"widgetType": self.widgetType,
|
"widgetType": self.widgetType,
|
||||||
})
|
})
|
||||||
|
|
||||||
def CustomType(io_type: str) -> type[IO_V3]:
|
def CustomType(io_type: IO | str) -> type[IO_V3]:
|
||||||
name = f"{io_type}_IO_V3"
|
name = f"{io_type}_IO_V3"
|
||||||
return type(name, (IO_V3,), {}, io_type=io_type)
|
return type(name, (IO_V3,), {}, io_type=io_type)
|
||||||
|
|
||||||
def CustomInput(id: str, io_type: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None) -> InputV3:
|
def CustomInput(id: str, io_type: IO | str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None) -> InputV3:
|
||||||
'''
|
'''
|
||||||
Defines input for 'io_type'. Can be used to stand in for non-core types.
|
Defines input for 'io_type'. Can be used to stand in for non-core types.
|
||||||
'''
|
'''
|
||||||
@ -92,7 +93,7 @@ def CustomInput(id: str, io_type: str, display_name: str=None, behavior=InputBeh
|
|||||||
}
|
}
|
||||||
return type(f"{io_type}Input", (InputV3,), {}, io_type=io_type)(**input_kwargs)
|
return type(f"{io_type}Input", (InputV3,), {}, io_type=io_type)(**input_kwargs)
|
||||||
|
|
||||||
def CustomOutput(id: str, io_type: str, display_name: str=None, tooltip: str=None) -> OutputV3:
|
def CustomOutput(id: str, io_type: IO | str, display_name: str=None, tooltip: str=None) -> OutputV3:
|
||||||
'''
|
'''
|
||||||
Defines output for 'io_type'. Can be used to stand in for non-core types.
|
Defines output for 'io_type'. Can be used to stand in for non-core types.
|
||||||
'''
|
'''
|
||||||
@ -104,7 +105,7 @@ def CustomOutput(id: str, io_type: str, display_name: str=None, tooltip: str=Non
|
|||||||
return type(f"{io_type}Output", (OutputV3,), {}, io_type=io_type)(**input_kwargs)
|
return type(f"{io_type}Output", (OutputV3,), {}, io_type=io_type)(**input_kwargs)
|
||||||
|
|
||||||
|
|
||||||
class BooleanInput(WidgetInputV3, io_type="BOOLEAN"):
|
class BooleanInput(WidgetInputV3, io_type=IO.BOOLEAN):
|
||||||
'''
|
'''
|
||||||
Boolean input.
|
Boolean input.
|
||||||
'''
|
'''
|
||||||
@ -122,7 +123,7 @@ class BooleanInput(WidgetInputV3, io_type="BOOLEAN"):
|
|||||||
"label_off": self.label_off,
|
"label_off": self.label_off,
|
||||||
})
|
})
|
||||||
|
|
||||||
class IntegerInput(WidgetInputV3, io_type="INT"):
|
class IntegerInput(WidgetInputV3, io_type=IO.INT):
|
||||||
'''
|
'''
|
||||||
Integer input.
|
Integer input.
|
||||||
'''
|
'''
|
||||||
@ -146,7 +147,7 @@ class IntegerInput(WidgetInputV3, io_type="INT"):
|
|||||||
"display": self.display_mode, # NOTE: in frontend, the parameter is called "display"
|
"display": self.display_mode, # NOTE: in frontend, the parameter is called "display"
|
||||||
})
|
})
|
||||||
|
|
||||||
class FloatInput(WidgetInputV3, io_type="FLOAT"):
|
class FloatInput(WidgetInputV3, io_type=IO.FLOAT):
|
||||||
'''
|
'''
|
||||||
Float input.
|
Float input.
|
||||||
'''
|
'''
|
||||||
@ -171,7 +172,7 @@ class FloatInput(WidgetInputV3, io_type="FLOAT"):
|
|||||||
"display": self.display_mode, # NOTE: in frontend, the parameter is called "display"
|
"display": self.display_mode, # NOTE: in frontend, the parameter is called "display"
|
||||||
})
|
})
|
||||||
|
|
||||||
class StringInput(WidgetInputV3, io_type="STRING"):
|
class StringInput(WidgetInputV3, io_type=IO.STRING):
|
||||||
'''
|
'''
|
||||||
String input.
|
String input.
|
||||||
'''
|
'''
|
||||||
@ -189,7 +190,7 @@ class StringInput(WidgetInputV3, io_type="STRING"):
|
|||||||
"placeholder": self.placeholder,
|
"placeholder": self.placeholder,
|
||||||
})
|
})
|
||||||
|
|
||||||
class ComboInput(WidgetInputV3, io_type="COMBO"):
|
class ComboInput(WidgetInputV3, io_type=IO.COMBO):
|
||||||
'''Combo input (dropdown).'''
|
'''Combo input (dropdown).'''
|
||||||
def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
|
||||||
default: str=None, control_after_generate: bool=None,
|
default: str=None, control_after_generate: bool=None,
|
||||||
@ -207,7 +208,7 @@ class ComboInput(WidgetInputV3, io_type="COMBO"):
|
|||||||
"control_after_generate": self.control_after_generate,
|
"control_after_generate": self.control_after_generate,
|
||||||
})
|
})
|
||||||
|
|
||||||
class MultiselectComboWidget(ComboInput, io_type="COMBO"):
|
class MultiselectComboWidget(ComboInput, io_type=IO.COMBO):
|
||||||
'''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
|
'''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
|
||||||
def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
|
||||||
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
|
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
|
||||||
@ -225,21 +226,21 @@ class MultiselectComboWidget(ComboInput, io_type="COMBO"):
|
|||||||
"chip": self.chip,
|
"chip": self.chip,
|
||||||
})
|
})
|
||||||
|
|
||||||
class ImageInput(InputV3, io_type="IMAGE"):
|
class ImageInput(InputV3, io_type=IO.IMAGE):
|
||||||
'''
|
'''
|
||||||
Image input.
|
Image input.
|
||||||
'''
|
'''
|
||||||
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
|
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
|
||||||
super().__init__(id, display_name, behavior, tooltip)
|
super().__init__(id, display_name, behavior, tooltip)
|
||||||
|
|
||||||
class MaskInput(InputV3, io_type="MASK"):
|
class MaskInput(InputV3, io_type=IO.MASK):
|
||||||
'''
|
'''
|
||||||
Mask input.
|
Mask input.
|
||||||
'''
|
'''
|
||||||
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
|
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
|
||||||
super().__init__(id, display_name, behavior, tooltip)
|
super().__init__(id, display_name, behavior, tooltip)
|
||||||
|
|
||||||
class LatentInput(InputV3, io_type="LATENT"):
|
class LatentInput(InputV3, io_type=IO.LATENT):
|
||||||
'''
|
'''
|
||||||
Latent input.
|
Latent input.
|
||||||
'''
|
'''
|
||||||
@ -250,7 +251,7 @@ class MultitypedInput(InputV3, io_type="COMFY_MULTITYPED_V3"):
|
|||||||
'''
|
'''
|
||||||
Input that permits more than one input type.
|
Input that permits more than one input type.
|
||||||
'''
|
'''
|
||||||
def __init__(self, id: str, io_types: list[Union[type[IO_V3], InputV3, str]], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None,):
|
def __init__(self, id: str, io_types: list[type[IO_V3] | InputV3 | IO |str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None,):
|
||||||
super().__init__(id, display_name, behavior, tooltip)
|
super().__init__(id, display_name, behavior, tooltip)
|
||||||
self._io_types = io_types
|
self._io_types = io_types
|
||||||
|
|
||||||
@ -283,24 +284,24 @@ class OutputV3:
|
|||||||
cls.io_type = io_type
|
cls.io_type = io_type
|
||||||
super().__init_subclass__(**kwargs)
|
super().__init_subclass__(**kwargs)
|
||||||
|
|
||||||
class IntegerOutput(OutputV3, io_type="INT"):
|
class IntegerOutput(OutputV3, io_type=IO.INT):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class FloatOutput(OutputV3, io_type="FLOAT"):
|
class FloatOutput(OutputV3, io_type=IO.FLOAT):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class StringOutput(OutputV3, io_type="STRING"):
|
class StringOutput(OutputV3, io_type=IO.STRING):
|
||||||
pass
|
pass
|
||||||
# def __init__(self, id: str, display_name: str=None, tooltip: str=None):
|
# def __init__(self, id: str, display_name: str=None, tooltip: str=None):
|
||||||
# super().__init__(id, display_name, tooltip)
|
# super().__init__(id, display_name, tooltip)
|
||||||
|
|
||||||
class ImageOutput(OutputV3, io_type="IMAGE"):
|
class ImageOutput(OutputV3, io_type=IO.IMAGE):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class MaskOutput(OutputV3, io_type="MASK"):
|
class MaskOutput(OutputV3, io_type=IO.MASK):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class LatentOutput(OutputV3, io_type="LATENT"):
|
class LatentOutput(OutputV3, io_type=IO.LATENT):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -675,6 +676,12 @@ class ComfyNodeV3(ABC):
|
|||||||
#--------------------------------------------
|
#--------------------------------------------
|
||||||
#############################################
|
#############################################
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
|
||||||
|
schema = cls.GET_SCHEMA()
|
||||||
|
# TODO: finish
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -690,26 +697,103 @@ class ComfyNodeV3(ABC):
|
|||||||
raise Exception("No DEFINE_SCHEMA function was defined for this node.")
|
raise Exception("No DEFINE_SCHEMA function was defined for this node.")
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def execute(self, inputs, outputs, hidden, **kwargs):
|
def execute(self, **kwargs) -> NodeOutput:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ReturnedInputs:
|
# class ReturnedInputs:
|
||||||
|
# def __init__(self):
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# class ReturnedOutputs:
|
||||||
|
# def __init__(self):
|
||||||
|
# pass
|
||||||
|
|
||||||
|
|
||||||
|
class NodeOutput:
|
||||||
|
'''
|
||||||
|
Standardized output of a node; can pass in any number of args and/or a UIOutput into 'ui' kwarg.
|
||||||
|
'''
|
||||||
|
def __init__(self, *args: Any, ui: UIOutput | dict=None, expand: dict=None, block_execution: str=None, **kwargs):
|
||||||
|
self.args = args
|
||||||
|
self.ui = ui
|
||||||
|
self.expand = expand
|
||||||
|
self.block_execution = block_execution
|
||||||
|
|
||||||
|
@property
|
||||||
|
def result(self):
|
||||||
|
return self.args if len(self.args) > 0 else None
|
||||||
|
|
||||||
|
|
||||||
|
class SavedResult:
|
||||||
|
def __init__(self, filename: str, subfolder: str, type: Literal["input", "output", "temp"]):
|
||||||
|
self.filename = filename
|
||||||
|
self.subfolder = subfolder
|
||||||
|
self.type = type
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
return {
|
||||||
|
"filename": self.filename,
|
||||||
|
"subfolder": self.subfolder,
|
||||||
|
"type": self.type
|
||||||
|
}
|
||||||
|
|
||||||
|
class UIOutput(ABC):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class ReturnedOutputs:
|
@abstractmethod
|
||||||
def __init__(self):
|
def as_dict(self) -> dict:
|
||||||
pass
|
... # TODO: finish
|
||||||
|
|
||||||
|
class UIImages(UIOutput):
|
||||||
|
def __init__(self, values: list[SavedResult | dict], animated=False, **kwargs):
|
||||||
|
self.values = values
|
||||||
|
self.animated = animated
|
||||||
|
|
||||||
class NodeOutputV3:
|
def as_dict(self):
|
||||||
def __init__(self):
|
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
|
||||||
pass
|
return {
|
||||||
|
"images": values,
|
||||||
|
"animated": (self.animated,)
|
||||||
|
}
|
||||||
|
|
||||||
class UINodeOutput:
|
class UILatents(UIOutput):
|
||||||
def __init__(self):
|
def __init__(self, values: list[SavedResult | dict], **kwargs):
|
||||||
pass
|
self.values = values
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
|
||||||
|
return {
|
||||||
|
"latents": values,
|
||||||
|
}
|
||||||
|
|
||||||
|
class UIAudio(UIOutput):
|
||||||
|
def __init__(self, values: list[SavedResult | dict], **kwargs):
|
||||||
|
self.values = values
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
|
||||||
|
return {
|
||||||
|
"audio": values,
|
||||||
|
}
|
||||||
|
|
||||||
|
class UI3D(UIOutput):
|
||||||
|
def __init__(self, values: list[SavedResult | dict], **kwargs):
|
||||||
|
self.values = values
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
values = [x.as_dict() if isinstance(x, SavedResult) else x for x in self.values]
|
||||||
|
return {
|
||||||
|
"3d": values,
|
||||||
|
}
|
||||||
|
|
||||||
|
class UIText(UIOutput):
|
||||||
|
def __init__(self, value: str, **kwargs):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
return {"text": (self.value,)}
|
||||||
|
|
||||||
|
|
||||||
class TestNode(ComfyNodeV3):
|
class TestNode(ComfyNodeV3):
|
||||||
|
@ -2,7 +2,7 @@ import torch
|
|||||||
|
|
||||||
from comfy_api.v3.io import (
|
from comfy_api.v3.io import (
|
||||||
ComfyNodeV3, SchemaV3, CustomType, CustomInput, CustomOutput, InputBehavior, NumberDisplay,
|
ComfyNodeV3, SchemaV3, CustomType, CustomInput, CustomOutput, InputBehavior, NumberDisplay,
|
||||||
IntegerInput, MaskInput, ImageInput, ComboDynamicInput,
|
IntegerInput, MaskInput, ImageInput, ComboDynamicInput, NodeOutput,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -14,7 +14,7 @@ class V3TestNode(ComfyNodeV3):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def DEFINE_SCHEMA(cls):
|
def DEFINE_SCHEMA(cls):
|
||||||
schema = SchemaV3(
|
return SchemaV3(
|
||||||
node_id="V3TestNode1",
|
node_id="V3TestNode1",
|
||||||
display_name="V3 Test Node (1djekjd)",
|
display_name="V3 Test Node (1djekjd)",
|
||||||
description="This is a funky V3 node test.",
|
description="This is a funky V3 node test.",
|
||||||
@ -36,10 +36,17 @@ class V3TestNode(ComfyNodeV3):
|
|||||||
],
|
],
|
||||||
is_output_node=True,
|
is_output_node=True,
|
||||||
)
|
)
|
||||||
return schema
|
|
||||||
|
|
||||||
def execute(self, some_int: int, image: torch.Tensor, mask: torch.Tensor=None, **kwargs):
|
def execute(self, some_int: int, image: torch.Tensor, mask: torch.Tensor=None, **kwargs):
|
||||||
return (None,)
|
a = NodeOutput(1)
|
||||||
|
aa = NodeOutput(1, "hellothere")
|
||||||
|
ab = NodeOutput(1, "hellothere", ui={"lol": "jk"})
|
||||||
|
b = NodeOutput()
|
||||||
|
c = NodeOutput(ui={"lol": "jk"})
|
||||||
|
return NodeOutput()
|
||||||
|
return NodeOutput(1)
|
||||||
|
return NodeOutput(1, block_execution="Kill yourself")
|
||||||
|
return ()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
17
execution.py
17
execution.py
@ -17,6 +17,7 @@ from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt,
|
|||||||
from comfy_execution.graph_utils import is_link, GraphBuilder
|
from comfy_execution.graph_utils import is_link, GraphBuilder
|
||||||
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
|
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
|
||||||
from comfy_execution.validation import validate_node_input
|
from comfy_execution.validation import validate_node_input
|
||||||
|
from comfy_api.v3.io import NodeOutput
|
||||||
|
|
||||||
class ExecutionResult(Enum):
|
class ExecutionResult(Enum):
|
||||||
SUCCESS = 0
|
SUCCESS = 0
|
||||||
@ -242,6 +243,22 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
|
|||||||
result = tuple([result] * len(obj.RETURN_TYPES))
|
result = tuple([result] * len(obj.RETURN_TYPES))
|
||||||
results.append(result)
|
results.append(result)
|
||||||
subgraph_results.append((None, result))
|
subgraph_results.append((None, result))
|
||||||
|
elif isinstance(r, NodeOutput):
|
||||||
|
if r.ui is not None:
|
||||||
|
uis.append(r.ui.as_dict())
|
||||||
|
if r.expand is not None:
|
||||||
|
has_subgraph = True
|
||||||
|
new_graph = r.expand
|
||||||
|
result = r.result
|
||||||
|
if r.block_execution is not None:
|
||||||
|
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
|
||||||
|
subgraph_results.append((new_graph, result))
|
||||||
|
elif r.result is not None:
|
||||||
|
result = r.result
|
||||||
|
if r.block_execution is not None:
|
||||||
|
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
|
||||||
|
results.append(result)
|
||||||
|
subgraph_results.append((None, result))
|
||||||
else:
|
else:
|
||||||
if isinstance(r, ExecutionBlocker):
|
if isinstance(r, ExecutionBlocker):
|
||||||
r = tuple([r] * len(obj.RETURN_TYPES))
|
r = tuple([r] * len(obj.RETURN_TYPES))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user