Add V3-to-V1 compatibility on early V3 node definition and node_info in server.py

This commit is contained in:
Jedrzej Kosinski 2025-05-28 20:56:25 -07:00
parent 880f756dc1
commit 96c2e3856d
2 changed files with 374 additions and 130 deletions

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from typing import Union, Any from typing import Union, Any
from enum import Enum from enum import Enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict
class InputBehavior(str, Enum): class InputBehavior(str, Enum):
required = "required" required = "required"
@ -45,6 +46,16 @@ class InputV3(IO_V3, io_type=None):
self.tooltip = tooltip self.tooltip = tooltip
self.lazy = lazy self.lazy = lazy
def as_dict_V1(self):
return prune_dict({
"display_name": self.display_name,
"tooltip": self.tooltip,
"lazy": self.lazy
})
def get_io_type_V1(self):
return self.io_type
class WidgetInputV3(InputV3, io_type=None): class WidgetInputV3(InputV3, io_type=None):
''' '''
Base class for a V3 Input with widget. Base class for a V3 Input with widget.
@ -57,6 +68,13 @@ class WidgetInputV3(InputV3, io_type=None):
self.socketless = socketless self.socketless = socketless
self.widgetType = widgetType self.widgetType = widgetType
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"default": self.default,
"socketless": self.socketless,
"widgetType": self.widgetType,
})
def CustomType(io_type: str) -> type[IO_V3]: def CustomType(io_type: str) -> type[IO_V3]:
name = f"{io_type}_IO_V3" name = f"{io_type}_IO_V3"
return type(name, (IO_V3,), {}, io_type=io_type) return type(name, (IO_V3,), {}, io_type=io_type)
@ -98,6 +116,12 @@ class BooleanInput(WidgetInputV3, io_type="BOOLEAN"):
self.label_off = label_off self.label_off = label_off
self.default: bool self.default: bool
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"label_on": self.label_on,
"label_off": self.label_off,
})
class IntegerInput(WidgetInputV3, io_type="INT"): class IntegerInput(WidgetInputV3, io_type="INT"):
''' '''
Integer input. Integer input.
@ -113,6 +137,15 @@ class IntegerInput(WidgetInputV3, io_type="INT"):
self.display_mode = display_mode self.display_mode = display_mode
self.default: int self.default: int
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"min": self.min,
"max": self.max,
"step": self.step,
"control_after_generate": self.control_after_generate,
"display_mode": self.display_mode,
})
class FloatInput(WidgetInputV3, io_type="FLOAT"): class FloatInput(WidgetInputV3, io_type="FLOAT"):
''' '''
Float input. Float input.
@ -129,6 +162,15 @@ class FloatInput(WidgetInputV3, io_type="FLOAT"):
self.display_mode = display_mode self.display_mode = display_mode
self.default: float self.default: float
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"min": self.min,
"max": self.max,
"step": self.step,
"round": self.round,
"display_mode": self.display_mode,
})
class StringInput(WidgetInputV3, io_type="STRING"): class StringInput(WidgetInputV3, io_type="STRING"):
''' '''
String input. String input.
@ -141,6 +183,12 @@ class StringInput(WidgetInputV3, io_type="STRING"):
self.placeholder = placeholder self.placeholder = placeholder
self.default: str self.default: str
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"multiline": self.multiline,
"placeholder": self.placeholder,
})
class ComboInput(WidgetInputV3, io_type="COMBO"): class ComboInput(WidgetInputV3, io_type="COMBO"):
'''Combo input (dropdown).''' '''Combo input (dropdown).'''
def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
@ -152,6 +200,13 @@ class ComboInput(WidgetInputV3, io_type="COMBO"):
self.control_after_generate = control_after_generate self.control_after_generate = control_after_generate
self.default: str self.default: str
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"multiselect": self.multiselect,
"options": self.options,
"control_after_generate": self.control_after_generate,
})
class MultiselectComboWidget(ComboInput, io_type="COMBO"): class MultiselectComboWidget(ComboInput, io_type="COMBO"):
'''Multiselect Combo input (dropdown for selecting potentially more than one value).''' '''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, options: list[str], display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
@ -163,6 +218,13 @@ class MultiselectComboWidget(ComboInput, io_type="COMBO"):
self.chip = chip self.chip = chip
self.default: list[str] self.default: list[str]
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"multiselect": self.multiselect,
"placeholder": self.placeholder,
"chip": self.chip,
})
class ImageInput(InputV3, io_type="IMAGE"): class ImageInput(InputV3, io_type="IMAGE"):
''' '''
Image input. Image input.
@ -205,6 +267,9 @@ class MultitypedInput(InputV3, io_type="COMFY_MULTITYPED_V3"):
io_types.append(x) io_types.append(x)
return io_types return io_types
def get_io_type_V1(self):
return ",".join(x.io_type for x in self.io_types)
class OutputV3: class OutputV3:
def __init__(self, id: str, display_name: str=None, tooltip: str=None, def __init__(self, id: str, display_name: str=None, tooltip: str=None,
@ -294,83 +359,52 @@ class Hidden(str, Enum):
api_key_comfy_org = "API_KEY_COMFY_ORG" api_key_comfy_org = "API_KEY_COMFY_ORG"
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend.""" """API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
# '''
# Request hidden value based on hidden_var key.
# '''
# def __init__(self, hidden_var: str):
# self.hidden_var = hidden_var
# NOTE: does this exist?
# class HiddenNodeId(Hidden):
# """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
# def __init__(self):
# super().__init__("NODE_ID")
# class HiddenUniqueId(Hidden):
# """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
# def __init__(self):
# super().__init__("UNIQUE_ID")
# class HiddenPrompt(Hidden):
# """PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
# def __init__(self):
# super().__init__("PROMPT")
# class HiddenExtraPngInfo(Hidden):
# """EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
# def __init__(self):
# super().__init__("EXTRA_PNGINFO")
# class HiddenDynPrompt(Hidden):
# """DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
# def __init__(self):
# super().__init__("DYNPROMPT")
# class HiddenAuthTokenComfyOrg(Hidden):
# """Token acquired from signing into a ComfyOrg account on frontend."""
# def __init__(self):
# super().__init__("AUTH_TOKEN_COMFY_ORG")
# class HiddenApiKeyComfyOrg(Hidden):
# """API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
# def __init__(self):
# super().__init__("API_KEY_COMFY_ORG")
# class HiddenParam: @dataclass
# def __init__(self): class NodeInfoV1:
# pass input: dict=None
input_order: dict[str, list[str]]=None
# def __init_subclass__(cls, hidden_var, **kwargs): output: list[str]=None
# cls.hidden_var = hidden_var output_is_list: list[bool]=None
# super().__init_subclass__(**kwargs) output_name: list[str]=None
output_tooltips: list[str]=None
# def Hidden(hidden_var: str) -> type[HiddenParam]: name: str=None
# return type(f"{hidden_var}_HIDDEN", (HiddenParam,), {}, hidden_var=hidden_var) display_name: str=None
description: str=None
python_module: Any=None
category: str=None
output_node: bool=None
deprecated: bool=None
experimental: bool=None
api_node: bool=None
def as_pruned_dict(dataclass_obj):
'''Return dict of dataclass object with pruned None values.'''
return prune_dict(asdict(dataclass_obj))
def prune_dict(d: dict):
return {k: v for k,v in d.items() if v is not None}
@dataclass
class SchemaV3: class SchemaV3:
def __init__(self, """Definition of V3 node properties."""
category: str,
inputs: list[InputV3], node_id: str
outputs: list[OutputV3]=None, """ID of node - should be globally unique. If this is a custom node, add a prefix or postfix to avoid name clashes."""
hidden: list[Hidden]=None, display_name: str = None
description: str="", """Display name of node."""
is_input_list: bool = False, category: str = "sd"
is_output_node: bool=False, """The category of the node, as per the "Add Node" menu."""
is_deprecated: bool=False, inputs: list[InputV3]=None
is_experimental: bool=False, outputs: list[OutputV3]=None
is_api_node: bool=False, hidden: list[Hidden]=None
): description: str=""
self.category = category """Node description, shown as a tooltip when hovering over the node."""
"""The category of the node, as per the "Add Node" menu.""" is_input_list: bool = False
self.inputs = inputs """A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
self.outputs = outputs
self.hidden = hidden
self.description = description
"""Node description, shown as a tooltip when hovering over the node."""
self.is_input_list = is_input_list
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``. All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.
@ -380,8 +414,8 @@ class SchemaV3:
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
""" """
self.is_output_node = is_output_node is_output_node: bool=False
"""Flags this node as an output node, causing any inputs it requires to be executed. """Flags this node as an output node, causing any inputs it requires to be executed.
If a node is not connected to any output nodes, that node will not be executed. Usage:: If a node is not connected to any output nodes, that node will not be executed. Usage::
@ -393,76 +427,272 @@ class SchemaV3:
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node
""" """
self.is_deprecated = is_deprecated is_deprecated: bool=False
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node.""" """Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
self.is_experimental = is_experimental is_experimental: bool=False
"""Flags a node as experimental, informing users that it may change or not work as expected.""" """Flags a node as experimental, informing users that it may change or not work as expected."""
self.is_api_node = is_api_node is_api_node: bool=False
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview.""" """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
# class SchemaV3Class:
# def __init__(self,
# node_id: str,
# node_name: str,
# category: str,
# inputs: list[InputV3],
# outputs: list[OutputV3]=None,
# hidden: list[Hidden]=None,
# description: str="",
# is_input_list: bool = False,
# is_output_node: bool=False,
# is_deprecated: bool=False,
# is_experimental: bool=False,
# is_api_node: bool=False,
# ):
# self.node_id = node_id
# """ID of node - should be globally unique. If this is a custom node, add a prefix or postfix to avoid name clashes."""
# self.node_name = node_name
# """Display name of node."""
# self.category = category
# """The category of the node, as per the "Add Node" menu."""
# self.inputs = inputs
# self.outputs = outputs
# self.hidden = hidden
# self.description = description
# """Node description, shown as a tooltip when hovering over the node."""
# self.is_input_list = is_input_list
# """A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
# All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.
# From the docs:
# A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
# Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
# """
# self.is_output_node = is_output_node
# """Flags this node as an output node, causing any inputs it requires to be executed.
# If a node is not connected to any output nodes, that node will not be executed. Usage::
# OUTPUT_NODE = True
# From the docs:
# By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
# Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node
# """
# self.is_deprecated = is_deprecated
# """Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
# self.is_experimental = is_experimental
# """Flags a node as experimental, informing users that it may change or not work as expected."""
# self.is_api_node = is_api_node
# """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
class classproperty(object):
def __init__(self, f):
self.f = f
def __get__(self, obj, owner):
return self.f(owner)
class ComfyNodeV3(ABC): class ComfyNodeV3(ABC):
"""Common base class for all V3 nodes."""
#############################################
# V1 Backwards Compatibility code
#--------------------------------------------
_DESCRIPTION = None
@classproperty
def DESCRIPTION(cls):
if not cls._DESCRIPTION:
cls.GET_SCHEMA()
return cls._DESCRIPTION
_CATEGORY = None
@classproperty
def CATEGORY(cls):
if not cls._CATEGORY:
cls.GET_SCHEMA()
return cls._CATEGORY
_EXPERIMENTAL = None
@classproperty
def EXPERIMENTAL(cls):
if not cls._EXPERIMENTAL:
cls.GET_SCHEMA()
return cls._EXPERIMENTAL
_DEPRECATED = None
@classproperty
def DEPRECATED(cls):
if not cls._DEPRECATED:
cls.GET_SCHEMA()
return cls._DEPRECATED
_API_NODE = None
@classproperty
def API_NODE(cls):
if not cls._API_NODE:
cls.GET_SCHEMA()
return cls._API_NODE
_OUTPUT_NODE = None
@classproperty
def OUTPUT_NODE(cls):
if not cls._OUTPUT_NODE:
cls.GET_SCHEMA()
return cls._OUTPUT_NODE
_INPUT_IS_LIST = None
@classproperty
def INPUT_IS_LIST(cls):
if not cls._INPUT_IS_LIST:
cls.GET_SCHEMA()
return cls._INPUT_IS_LIST
_OUTPUT_IS_LIST = None
@classproperty
def OUTPUT_IS_LIST(cls):
if not cls._OUTPUT_IS_LIST:
cls.GET_SCHEMA()
return cls._OUTPUT_IS_LIST
_RETURN_TYPES = None
@classproperty
def RETURN_TYPES(cls):
if not cls._RETURN_TYPES:
cls.GET_SCHEMA()
return cls._RETURN_TYPES
_RETURN_NAMES = None
@classproperty
def RETURN_NAMES(cls):
if not cls._RETURN_NAMES:
cls.GET_SCHEMA()
return cls._RETURN_NAMES
_OUTPUT_TOOLTIPS = None
@classproperty
def OUTPUT_TOOLTIPS(cls):
if not cls._OUTPUT_TOOLTIPS:
cls.GET_SCHEMA()
return cls._OUTPUT_TOOLTIPS
FUNCTION = "execute"
@classmethod
def INPUT_TYPES(cls) -> dict[str, dict]:
schema = cls.DEFINE_SCHEMA()
# for V1, make inputs be a dict with potential keys {required, optional, hidden}
input = {
"required": {}
}
if schema.inputs:
for i in schema.inputs:
input.setdefault(i.behavior.value, {})[i.id] = (i.get_io_type_V1(), i.as_dict_V1())
if schema.hidden:
for hidden in schema.hidden:
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
return input
@classmethod @classmethod
def GET_SCHEMA(cls) -> SchemaV3: def GET_SCHEMA(cls) -> SchemaV3:
schema = cls.DEFINE_SCHEMA()
if cls._DESCRIPTION is None:
cls._DESCRIPTION = schema.description
if cls._CATEGORY is None:
cls._CATEGORY = schema.category
if cls._EXPERIMENTAL is None:
cls._EXPERIMENTAL = schema.is_experimental
if cls._DEPRECATED is None:
cls._DEPRECATED = schema.is_deprecated
if cls._API_NODE is None:
cls._API_NODE = schema.is_api_node
if cls._OUTPUT_NODE is None:
cls._OUTPUT_NODE = schema.is_output_node
if cls._INPUT_IS_LIST is None:
cls._INPUT_IS_LIST = schema.is_input_list
if cls._RETURN_TYPES is None:
output = []
output_name = []
output_is_list = []
output_tooltips = []
if schema.outputs:
for o in schema.outputs:
output.append(o.io_type)
output_name.append(o.display_name if o.display_name else o.io_type)
output_is_list.append(o.is_output_list)
output_tooltips.append(o.tooltip if o.tooltip else None)
cls._RETURN_TYPES = output
cls._RETURN_NAMES = output_name
cls._OUTPUT_IS_LIST = output_is_list
cls._OUTPUT_TOOLTIPS = output_tooltips
return schema
@classmethod
def GET_NODE_INFO_V1(cls) -> dict[str, Any]:
schema = cls.GET_SCHEMA()
# get V1 inputs
input = cls.INPUT_TYPES()
# create separate lists from output fields
output = []
output_is_list = []
output_name = []
output_tooltips = []
if schema.outputs:
for o in schema.outputs:
output.append(o.io_type)
output_is_list.append(o.is_output_list)
output_name.append(o.display_name if o.display_name else o.io_type)
output_tooltips.append(o.tooltip if o.tooltip else None)
info = NodeInfoV1(
input=input,
input_order={key: list(value.keys()) for (key, value) in input.items()},
output=output,
output_is_list=output_is_list,
output_name=output_name,
output_tooltips=output_tooltips,
name=schema.node_id,
display_name=schema.display_name,
category=schema.category,
description=schema.description,
output_node=schema.is_output_node,
deprecated=schema.is_deprecated,
experimental=schema.is_experimental,
api_node=schema.is_api_node,
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes")
)
return asdict(info)
#--------------------------------------------
#############################################
@classmethod
@abstractmethod
def DEFINE_SCHEMA(cls) -> SchemaV3:
""" """
Override this function with one that returns a SchemaV3 instance. Override this function with one that returns a SchemaV3 instance.
""" """
return None return None
GET_SCHEMA = None DEFINE_SCHEMA = None
def __init__(self): def __init__(self):
if self.GET_SCHEMA is None: if self.DEFINE_SCHEMA is None:
raise Exception("No GET_SCHEMA function was defined for this node.") raise Exception("No DEFINE_SCHEMA function was defined for this node.")
@abstractmethod @abstractmethod
def execute(self, inputs, outputs, hidden, **kwargs): def execute(self, inputs, outputs, hidden, **kwargs):
pass pass
# @classmethod
# @abstractmethod
# def INPUTS(cls) -> list[InputV3]:
# pass
# @classmethod
# @abstractmethod
# def OUTPUTS(cls) -> list[OutputV3]:
# pass
# @abstractmethod
# def execute(self, inputs, outputs, hidden):
# pass
# class ComfyNodeV3:
# INPUTS = [
# ImageInput("image"),
# IntegerInput("count", min=1, max=6),
# ]
# OUTPUTS = [
# ]
# OUTPUTS = [
# ImageOutput(),
# ]
# class CustomInput(InputV3):
# def __init__(self, id: str, io_type: str):
# super().__init__(id)
# IO_TYPE = IO_TYPE
# class AnimateDiffModelInput(InputV3, io_type="MODEL_M"):
# def __init__(self):
# pass
# def execute(inputs, outputs, hidden):
# pass
class ReturnedInputs: class ReturnedInputs:
def __init__(self): def __init__(self):
@ -484,8 +714,12 @@ class UINodeOutput:
class TestNode(ComfyNodeV3): class TestNode(ComfyNodeV3):
SCHEMA = SchemaV3( SCHEMA = SchemaV3(
node_id="TestNode_v3",
display_name="Test Node (V3)",
category="v3_test", category="v3_test",
inputs=[], inputs=[IntegerInput("my_int")],
outputs=[ImageOutput("image_output")],
hidden=[Hidden.api_key_comfy_org, Hidden.auth_token_comfy_org, Hidden.unique_id]
) )
# @classmethod # @classmethod
@ -493,12 +727,13 @@ class TestNode(ComfyNodeV3):
# return cls.SCHEMA # return cls.SCHEMA
@classmethod @classmethod
def GET_SCHEMA(cls): def DEFINE_SCHEMA(cls):
return cls.SCHEMA return cls.SCHEMA
def execute(**kwargs): def execute(**kwargs):
pass pass
if __name__ == "__main__": if __name__ == "__main__":
print("hello there") print("hello there")
inputs: list[InputV3] = [ inputs: list[InputV3] = [
@ -518,10 +753,16 @@ if __name__ == "__main__":
for c in inputs: for c in inputs:
if isinstance(c, MultitypedInput): if isinstance(c, MultitypedInput):
print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}, {[x.io_type for x in c.io_types]}") print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}, {[x.io_type for x in c.io_types]}")
print(c.get_io_type_V1())
else: else:
print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}") print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}")
for c in outputs: for c in outputs:
print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}") print(f"{c}, {type(c)}, {type(c).io_type}, {c.id}")
zzz = TestNode() zz = TestNode()
print(zz.GET_NODE_INFO_V1())
# aa = NodeInfoV1()
# print(asdict(aa))
# print(as_pruned_dict(aa))

View File

@ -29,6 +29,7 @@ import comfy.model_management
import node_helpers import node_helpers
from comfyui_version import __version__ from comfyui_version import __version__
from app.frontend_management import FrontendManager from app.frontend_management import FrontendManager
from comfy_api.v3.io import ComfyNodeV3
from app.user_manager import UserManager from app.user_manager import UserManager
from app.model_manager import ModelFileManager from app.model_manager import ModelFileManager
@ -555,6 +556,8 @@ class PromptServer():
def node_info(node_class): def node_info(node_class):
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
if isinstance(obj_class, ComfyNodeV3):
return obj_class.GET_NODE_INFO_V1()
info = {} info = {}
info['input'] = obj_class.INPUT_TYPES() info['input'] = obj_class.INPUT_TYPES()
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}