Changed execute instance method to EXECUTE class method, added countermeasures to avoid state leaks, ready ability to add extra params to clean class type clone

This commit is contained in:
Jedrzej Kosinski 2025-06-05 04:12:44 -07:00
parent a7f515e913
commit d79a3cf990
3 changed files with 95 additions and 36 deletions

View File

@ -532,6 +532,30 @@ class SchemaV3:
# """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview.""" # """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
class Serializer:
def __init_subclass__(cls, io_type: IO | str, **kwargs):
cls.io_type = io_type
super().__init_subclass__(**kwargs)
@classmethod
def serialize(cls, o: Any) -> str:
pass
@classmethod
def deserialize(cls, s: str) -> Any:
pass
def prepare_class_clone(c: ComfyNodeV3 | type[ComfyNodeV3]) -> type[ComfyNodeV3]:
"""Creates clone of real node class to prevent monkey-patching."""
c_type: type[ComfyNodeV3] = c if is_class(c) else type(c)
type_clone: type[ComfyNodeV3] = type(f"CLEAN_{c_type.__name__}", c_type.__bases__, {})
# TODO: what parameters should be carried over?
type_clone.SCHEMA = c_type.SCHEMA
# TODO: add anything we would want to expose inside node's EXECUTE function
return type_clone
class classproperty(object): class classproperty(object):
def __init__(self, f): def __init__(self, f):
self.f = f self.f = f
@ -543,6 +567,43 @@ class ComfyNodeV3(ABC):
"""Common base class for all V3 nodes.""" """Common base class for all V3 nodes."""
RELATIVE_PYTHON_MODULE = None RELATIVE_PYTHON_MODULE = None
SCHEMA = None
@classmethod
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
schema = cls.GET_SCHEMA()
# TODO: finish
return None
@classmethod
@abstractmethod
def DEFINE_SCHEMA(cls) -> SchemaV3:
"""
Override this function with one that returns a SchemaV3 instance.
"""
return None
DEFINE_SCHEMA = None
@classmethod
@abstractmethod
def EXECUTE(cls, **kwargs) -> NodeOutput:
pass
EXECUTE = None
@classmethod
def GET_SERIALIZERS(cls) -> list[Serializer]:
return []
def __init__(self):
self.__class__.VALIDATE_CLASS()
@classmethod
def VALIDATE_CLASS(cls):
if not callable(cls.DEFINE_SCHEMA):
raise Exception(f"No DEFINE_SCHEMA function was defined for node class {cls.__name__}.")
if not callable(cls.EXECUTE):
raise Exception(f"No execute function was defined for node class {cls.__name__}.")
############################################# #############################################
# V1 Backwards Compatibility code # V1 Backwards Compatibility code
#-------------------------------------------- #--------------------------------------------
@ -623,7 +684,7 @@ class ComfyNodeV3(ABC):
cls.GET_SCHEMA() cls.GET_SCHEMA()
return cls._OUTPUT_TOOLTIPS return cls._OUTPUT_TOOLTIPS
FUNCTION = "execute" FUNCTION = "EXECUTE"
@classmethod @classmethod
def INPUT_TYPES(cls) -> dict[str, dict]: def INPUT_TYPES(cls) -> dict[str, dict]:
@ -642,6 +703,7 @@ class ComfyNodeV3(ABC):
@classmethod @classmethod
def GET_SCHEMA(cls) -> SchemaV3: def GET_SCHEMA(cls) -> SchemaV3:
cls.VALIDATE_CLASS()
schema = cls.DEFINE_SCHEMA() schema = cls.DEFINE_SCHEMA()
if cls._DESCRIPTION is None: if cls._DESCRIPTION is None:
cls._DESCRIPTION = schema.description cls._DESCRIPTION = schema.description
@ -674,7 +736,7 @@ class ComfyNodeV3(ABC):
cls._RETURN_NAMES = output_name cls._RETURN_NAMES = output_name
cls._OUTPUT_IS_LIST = output_is_list cls._OUTPUT_IS_LIST = output_is_list
cls._OUTPUT_TOOLTIPS = output_tooltips cls._OUTPUT_TOOLTIPS = output_tooltips
cls.SCHEMA = schema
return schema return schema
@classmethod @classmethod
@ -716,31 +778,6 @@ class ComfyNodeV3(ABC):
#-------------------------------------------- #--------------------------------------------
############################################# #############################################
@classmethod
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
schema = cls.GET_SCHEMA()
# TODO: finish
return None
@classmethod
@abstractmethod
def DEFINE_SCHEMA(cls) -> SchemaV3:
"""
Override this function with one that returns a SchemaV3 instance.
"""
return None
DEFINE_SCHEMA = None
def __init__(self):
if self.DEFINE_SCHEMA is None:
raise Exception("No DEFINE_SCHEMA function was defined for this node.")
@abstractmethod
def execute(self, **kwargs) -> NodeOutput:
pass
# class ReturnedInputs: # class ReturnedInputs:
# def __init__(self): # def __init__(self):
# pass # pass
@ -857,19 +894,20 @@ class TestNode(ComfyNodeV3):
def DEFINE_SCHEMA(cls): def DEFINE_SCHEMA(cls):
return cls.SCHEMA return cls.SCHEMA
def execute(**kwargs): def EXECUTE(**kwargs):
pass pass
if __name__ == "__main__": if __name__ == "__main__":
print("hello there") print("hello there")
inputs: list[InputV3] = [ inputs: list[InputV3] = [
IntegerInput("tessfes", widgetType=IO.STRING),
IntegerInput("my_int"), IntegerInput("my_int"),
CustomInput("xyz", "XYZ"), CustomInput("xyz", "XYZ"),
CustomInput("model1", "MODEL_M"), CustomInput("model1", "MODEL_M"),
ImageInput("my_image"), ImageInput("my_image"),
FloatInput("my_float"), FloatInput("my_float"),
MultitypedInput("my_inputs", [CustomType("MODEL_M"), CustomType("XYZ")]), MultitypedInput("my_inputs", [StringInput, CustomType("MODEL_M"), CustomType("XYZ")]),
] ]
outputs: list[OutputV3] = [ outputs: list[OutputV3] = [

View File

@ -1,13 +1,18 @@
import torch import torch
from comfy_api.v3.io import ( from comfy_api.v3.io import (
ComfyNodeV3, SchemaV3, InputBehavior, NumberDisplay, ComfyNodeV3, SchemaV3, InputBehavior, NumberDisplay,
IntegerInput, MaskInput, ImageInput, ComboInput, CustomInput, IntegerInput, MaskInput, ImageInput, ComboInput, CustomInput, StringInput, CustomType,
IntegerOutput, ImageOutput, IntegerOutput, ImageOutput, MultitypedInput,
NodeOutput, NodeOutput, Hidden
) )
import logging
class V3TestNode(ComfyNodeV3): class V3TestNode(ComfyNodeV3):
def __init__(self):
self.hahajkunless = ";)"
@classmethod @classmethod
def DEFINE_SCHEMA(cls): def DEFINE_SCHEMA(cls):
return SchemaV3( return SchemaV3(
@ -17,7 +22,7 @@ class V3TestNode(ComfyNodeV3):
category="v3 nodes", category="v3 nodes",
inputs=[ inputs=[
ImageInput("image", display_name="new_image"), ImageInput("image", display_name="new_image"),
CustomInput("xyz", "XYZ"), CustomInput("xyz", "XYZ", behavior=InputBehavior.optional),
MaskInput("mask", behavior=InputBehavior.optional), MaskInput("mask", behavior=InputBehavior.optional),
IntegerInput("some_int", display_name="new_name", min=0, max=127, default=42, IntegerInput("some_int", display_name="new_name", min=0, max=127, default=42,
tooltip="My tooltip 😎", display_mode=NumberDisplay.slider), tooltip="My tooltip 😎", display_mode=NumberDisplay.slider),
@ -42,11 +47,20 @@ class V3TestNode(ComfyNodeV3):
outputs=[ outputs=[
IntegerOutput("int_output"), IntegerOutput("int_output"),
ImageOutput("img_output", display_name="img🖼", tooltip="This is an image"), ImageOutput("img_output", display_name="img🖼", tooltip="This is an image"),
],
hidden=[
], ],
is_output_node=True, is_output_node=True,
) )
def execute(self, image: torch.Tensor, xyz, some_int: int, combo: str, mask: torch.Tensor=None): @classmethod
def EXECUTE(cls, image: torch.Tensor, some_int: int, combo: str, xyz=None, mask: torch.Tensor=None):
if hasattr(cls, "hahajkunless"):
raise Exception("The 'cls' variable leaked instance state between runs!")
if hasattr(cls, "doohickey"):
raise Exception("The 'cls' variable leaked state on class properties between runs!")
cls.doohickey = "LOLJK"
return NodeOutput(some_int, image) return NodeOutput(some_int, image)

View File

@ -17,7 +17,7 @@ from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt,
from comfy_execution.graph_utils import is_link, GraphBuilder from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.validation import validate_node_input from comfy_execution.validation import validate_node_input
from comfy_api.v3.io import NodeOutput from comfy_api.v3.io import NodeOutput, ComfyNodeV3, prepare_class_clone
class ExecutionResult(Enum): class ExecutionResult(Enum):
SUCCESS = 0 SUCCESS = 0
@ -183,7 +183,14 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
if execution_block is None: if execution_block is None:
if pre_execute_cb is not None and index is not None: if pre_execute_cb is not None and index is not None:
pre_execute_cb(index) pre_execute_cb(index)
results.append(getattr(obj, func)(**inputs)) # V3
if isinstance(obj, ComfyNodeV3):
type(obj).VALIDATE_CLASS()
class_clone = prepare_class_clone(obj)
results.append(type(obj).EXECUTE.__func__(class_clone, **inputs))
# V1
else:
results.append(getattr(obj, func)(**inputs))
else: else:
results.append(execution_block) results.append(execution_block)