mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-27 16:26:39 +00:00
Changed execute instance method to EXECUTE class method, added countermeasures to avoid state leaks, ready ability to add extra params to clean class type clone
This commit is contained in:
parent
a7f515e913
commit
d79a3cf990
@ -532,6 +532,30 @@ class SchemaV3:
|
|||||||
# """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
|
# """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
|
||||||
|
|
||||||
|
|
||||||
|
class Serializer:
|
||||||
|
def __init_subclass__(cls, io_type: IO | str, **kwargs):
|
||||||
|
cls.io_type = io_type
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def serialize(cls, o: Any) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def deserialize(cls, s: str) -> Any:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_class_clone(c: ComfyNodeV3 | type[ComfyNodeV3]) -> type[ComfyNodeV3]:
|
||||||
|
"""Creates clone of real node class to prevent monkey-patching."""
|
||||||
|
c_type: type[ComfyNodeV3] = c if is_class(c) else type(c)
|
||||||
|
type_clone: type[ComfyNodeV3] = type(f"CLEAN_{c_type.__name__}", c_type.__bases__, {})
|
||||||
|
# TODO: what parameters should be carried over?
|
||||||
|
type_clone.SCHEMA = c_type.SCHEMA
|
||||||
|
# TODO: add anything we would want to expose inside node's EXECUTE function
|
||||||
|
return type_clone
|
||||||
|
|
||||||
|
|
||||||
class classproperty(object):
|
class classproperty(object):
|
||||||
def __init__(self, f):
|
def __init__(self, f):
|
||||||
self.f = f
|
self.f = f
|
||||||
@ -543,6 +567,43 @@ class ComfyNodeV3(ABC):
|
|||||||
"""Common base class for all V3 nodes."""
|
"""Common base class for all V3 nodes."""
|
||||||
|
|
||||||
RELATIVE_PYTHON_MODULE = None
|
RELATIVE_PYTHON_MODULE = None
|
||||||
|
SCHEMA = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
|
||||||
|
schema = cls.GET_SCHEMA()
|
||||||
|
# TODO: finish
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def DEFINE_SCHEMA(cls) -> SchemaV3:
|
||||||
|
"""
|
||||||
|
Override this function with one that returns a SchemaV3 instance.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
DEFINE_SCHEMA = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def EXECUTE(cls, **kwargs) -> NodeOutput:
|
||||||
|
pass
|
||||||
|
EXECUTE = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def GET_SERIALIZERS(cls) -> list[Serializer]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.__class__.VALIDATE_CLASS()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def VALIDATE_CLASS(cls):
|
||||||
|
if not callable(cls.DEFINE_SCHEMA):
|
||||||
|
raise Exception(f"No DEFINE_SCHEMA function was defined for node class {cls.__name__}.")
|
||||||
|
if not callable(cls.EXECUTE):
|
||||||
|
raise Exception(f"No execute function was defined for node class {cls.__name__}.")
|
||||||
|
|
||||||
#############################################
|
#############################################
|
||||||
# V1 Backwards Compatibility code
|
# V1 Backwards Compatibility code
|
||||||
#--------------------------------------------
|
#--------------------------------------------
|
||||||
@ -623,7 +684,7 @@ class ComfyNodeV3(ABC):
|
|||||||
cls.GET_SCHEMA()
|
cls.GET_SCHEMA()
|
||||||
return cls._OUTPUT_TOOLTIPS
|
return cls._OUTPUT_TOOLTIPS
|
||||||
|
|
||||||
FUNCTION = "execute"
|
FUNCTION = "EXECUTE"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls) -> dict[str, dict]:
|
def INPUT_TYPES(cls) -> dict[str, dict]:
|
||||||
@ -642,6 +703,7 @@ class ComfyNodeV3(ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def GET_SCHEMA(cls) -> SchemaV3:
|
def GET_SCHEMA(cls) -> SchemaV3:
|
||||||
|
cls.VALIDATE_CLASS()
|
||||||
schema = cls.DEFINE_SCHEMA()
|
schema = cls.DEFINE_SCHEMA()
|
||||||
if cls._DESCRIPTION is None:
|
if cls._DESCRIPTION is None:
|
||||||
cls._DESCRIPTION = schema.description
|
cls._DESCRIPTION = schema.description
|
||||||
@ -674,7 +736,7 @@ class ComfyNodeV3(ABC):
|
|||||||
cls._RETURN_NAMES = output_name
|
cls._RETURN_NAMES = output_name
|
||||||
cls._OUTPUT_IS_LIST = output_is_list
|
cls._OUTPUT_IS_LIST = output_is_list
|
||||||
cls._OUTPUT_TOOLTIPS = output_tooltips
|
cls._OUTPUT_TOOLTIPS = output_tooltips
|
||||||
|
cls.SCHEMA = schema
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -716,31 +778,6 @@ class ComfyNodeV3(ABC):
|
|||||||
#--------------------------------------------
|
#--------------------------------------------
|
||||||
#############################################
|
#############################################
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
|
|
||||||
schema = cls.GET_SCHEMA()
|
|
||||||
# TODO: finish
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@abstractmethod
|
|
||||||
def DEFINE_SCHEMA(cls) -> SchemaV3:
|
|
||||||
"""
|
|
||||||
Override this function with one that returns a SchemaV3 instance.
|
|
||||||
"""
|
|
||||||
return None
|
|
||||||
DEFINE_SCHEMA = None
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
if self.DEFINE_SCHEMA is None:
|
|
||||||
raise Exception("No DEFINE_SCHEMA function was defined for this node.")
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def execute(self, **kwargs) -> NodeOutput:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# class ReturnedInputs:
|
# class ReturnedInputs:
|
||||||
# def __init__(self):
|
# def __init__(self):
|
||||||
# pass
|
# pass
|
||||||
@ -857,19 +894,20 @@ class TestNode(ComfyNodeV3):
|
|||||||
def DEFINE_SCHEMA(cls):
|
def DEFINE_SCHEMA(cls):
|
||||||
return cls.SCHEMA
|
return cls.SCHEMA
|
||||||
|
|
||||||
def execute(**kwargs):
|
def EXECUTE(**kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print("hello there")
|
print("hello there")
|
||||||
inputs: list[InputV3] = [
|
inputs: list[InputV3] = [
|
||||||
|
IntegerInput("tessfes", widgetType=IO.STRING),
|
||||||
IntegerInput("my_int"),
|
IntegerInput("my_int"),
|
||||||
CustomInput("xyz", "XYZ"),
|
CustomInput("xyz", "XYZ"),
|
||||||
CustomInput("model1", "MODEL_M"),
|
CustomInput("model1", "MODEL_M"),
|
||||||
ImageInput("my_image"),
|
ImageInput("my_image"),
|
||||||
FloatInput("my_float"),
|
FloatInput("my_float"),
|
||||||
MultitypedInput("my_inputs", [CustomType("MODEL_M"), CustomType("XYZ")]),
|
MultitypedInput("my_inputs", [StringInput, CustomType("MODEL_M"), CustomType("XYZ")]),
|
||||||
]
|
]
|
||||||
|
|
||||||
outputs: list[OutputV3] = [
|
outputs: list[OutputV3] = [
|
||||||
|
@ -1,13 +1,18 @@
|
|||||||
import torch
|
import torch
|
||||||
from comfy_api.v3.io import (
|
from comfy_api.v3.io import (
|
||||||
ComfyNodeV3, SchemaV3, InputBehavior, NumberDisplay,
|
ComfyNodeV3, SchemaV3, InputBehavior, NumberDisplay,
|
||||||
IntegerInput, MaskInput, ImageInput, ComboInput, CustomInput,
|
IntegerInput, MaskInput, ImageInput, ComboInput, CustomInput, StringInput, CustomType,
|
||||||
IntegerOutput, ImageOutput,
|
IntegerOutput, ImageOutput, MultitypedInput,
|
||||||
NodeOutput,
|
NodeOutput, Hidden
|
||||||
)
|
)
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
class V3TestNode(ComfyNodeV3):
|
class V3TestNode(ComfyNodeV3):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.hahajkunless = ";)"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def DEFINE_SCHEMA(cls):
|
def DEFINE_SCHEMA(cls):
|
||||||
return SchemaV3(
|
return SchemaV3(
|
||||||
@ -17,7 +22,7 @@ class V3TestNode(ComfyNodeV3):
|
|||||||
category="v3 nodes",
|
category="v3 nodes",
|
||||||
inputs=[
|
inputs=[
|
||||||
ImageInput("image", display_name="new_image"),
|
ImageInput("image", display_name="new_image"),
|
||||||
CustomInput("xyz", "XYZ"),
|
CustomInput("xyz", "XYZ", behavior=InputBehavior.optional),
|
||||||
MaskInput("mask", behavior=InputBehavior.optional),
|
MaskInput("mask", behavior=InputBehavior.optional),
|
||||||
IntegerInput("some_int", display_name="new_name", min=0, max=127, default=42,
|
IntegerInput("some_int", display_name="new_name", min=0, max=127, default=42,
|
||||||
tooltip="My tooltip 😎", display_mode=NumberDisplay.slider),
|
tooltip="My tooltip 😎", display_mode=NumberDisplay.slider),
|
||||||
@ -42,11 +47,20 @@ class V3TestNode(ComfyNodeV3):
|
|||||||
outputs=[
|
outputs=[
|
||||||
IntegerOutput("int_output"),
|
IntegerOutput("int_output"),
|
||||||
ImageOutput("img_output", display_name="img🖼️", tooltip="This is an image"),
|
ImageOutput("img_output", display_name="img🖼️", tooltip="This is an image"),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
|
||||||
],
|
],
|
||||||
is_output_node=True,
|
is_output_node=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def execute(self, image: torch.Tensor, xyz, some_int: int, combo: str, mask: torch.Tensor=None):
|
@classmethod
|
||||||
|
def EXECUTE(cls, image: torch.Tensor, some_int: int, combo: str, xyz=None, mask: torch.Tensor=None):
|
||||||
|
if hasattr(cls, "hahajkunless"):
|
||||||
|
raise Exception("The 'cls' variable leaked instance state between runs!")
|
||||||
|
if hasattr(cls, "doohickey"):
|
||||||
|
raise Exception("The 'cls' variable leaked state on class properties between runs!")
|
||||||
|
cls.doohickey = "LOLJK"
|
||||||
return NodeOutput(some_int, image)
|
return NodeOutput(some_int, image)
|
||||||
|
|
||||||
|
|
||||||
|
11
execution.py
11
execution.py
@ -17,7 +17,7 @@ from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt,
|
|||||||
from comfy_execution.graph_utils import is_link, GraphBuilder
|
from comfy_execution.graph_utils import is_link, GraphBuilder
|
||||||
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
|
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
|
||||||
from comfy_execution.validation import validate_node_input
|
from comfy_execution.validation import validate_node_input
|
||||||
from comfy_api.v3.io import NodeOutput
|
from comfy_api.v3.io import NodeOutput, ComfyNodeV3, prepare_class_clone
|
||||||
|
|
||||||
class ExecutionResult(Enum):
|
class ExecutionResult(Enum):
|
||||||
SUCCESS = 0
|
SUCCESS = 0
|
||||||
@ -183,7 +183,14 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
|
|||||||
if execution_block is None:
|
if execution_block is None:
|
||||||
if pre_execute_cb is not None and index is not None:
|
if pre_execute_cb is not None and index is not None:
|
||||||
pre_execute_cb(index)
|
pre_execute_cb(index)
|
||||||
results.append(getattr(obj, func)(**inputs))
|
# V3
|
||||||
|
if isinstance(obj, ComfyNodeV3):
|
||||||
|
type(obj).VALIDATE_CLASS()
|
||||||
|
class_clone = prepare_class_clone(obj)
|
||||||
|
results.append(type(obj).EXECUTE.__func__(class_clone, **inputs))
|
||||||
|
# V1
|
||||||
|
else:
|
||||||
|
results.append(getattr(obj, func)(**inputs))
|
||||||
else:
|
else:
|
||||||
results.append(execution_block)
|
results.append(execution_block)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user