Add comfytype decorator, convert all relevant v3_01 types to follow new convention, make v1 test node have xyz be optional

This commit is contained in:
kosinkadink1@gmail.com 2025-06-13 04:06:06 -07:00
parent cf7312d82c
commit 54e0d6b161
3 changed files with 243 additions and 209 deletions

View File

@ -1,22 +1,16 @@
from __future__ import annotations
from typing import Any, Literal, TYPE_CHECKING, TypeVar
from typing import Any, Literal, TYPE_CHECKING, TypeVar, Callable, Optional, cast
from enum import Enum
from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict
from comfy.comfy_types.node_typing import IO
# NOTE: these imports here are mostly for keeping execution.py happy with type inheritance
from comfy_api.v3.io import ComfyNodeV3 as BASE_CV3
from comfy_api.v3.io import NodeOutput as BASE_NO
# if TYPE_CHECKING:
import torch
class InputBehavior(str, Enum):
'''Likely deprecated; required/optional can be a bool, unlikely to be more categories that fit.'''
required = "required"
optional = "optional"
class FolderType(str, Enum):
input = "input"
output = "output"
@ -62,23 +56,75 @@ class NumberDisplay(str, Enum):
slider = "slider"
class ComfyType:
Type = Any
io_type: str = None
Input = None
Output = None
# NOTE: this is a workaround to make the decorator return the correct type
T = TypeVar("T", bound=type)
def comfytype(io_type: str, **kwargs):
'''
Decorator to mark nested classes as ComfyType; io_type will be bound to the class.
A ComfyType may have the following attributes:
- Type = <type hint here>
- class Input(InputV3): ...
- class Output(OutputV3): ...
'''
def decorator(cls: T) -> T:
if not isinstance(cls, ComfyType):
# copy class attributes except for special ones that shouldn't be in type()
cls_dict = {
k: v for k, v in cls.__dict__.items()
if k not in ('__dict__', '__weakref__', '__module__', '__doc__')
}
# new class
new_cls: ComfyType = type(
cls.__name__,
(cls, ComfyType),
cls_dict
)
# metadata preservation
new_cls.__module__ = cls.__module__
new_cls.__doc__ = cls.__doc__
# assign ComfyType attributes, if needed
# NOTE: do we need __ne__ trick for io_type? (see IO.__ne__ for details)
else:
new_cls = cls
new_cls.io_type = io_type
if new_cls.Input is not None:
new_cls.Input.Parent = new_cls
if new_cls.Output is not None:
new_cls.Output.Parent = new_cls
return new_cls
return decorator
class IO_V3:
'''
Base class for V3 Inputs and Outputs.
'''
Type = Any
Parent: ComfyType = None
def __init__(self):
pass
def __init_subclass__(cls, io_type: IO | str, **kwargs):
# TODO: do we need __ne__ trick for io_type? (see IO.__ne__ for details)
cls.io_type = io_type
super().__init_subclass__(**kwargs)
# def __init_subclass__(cls, io_type: IO | str, **kwargs):
# # TODO: do we need __ne__ trick for io_type? (see IO.__ne__ for details)
# cls.io_type = io_type
# super().__init_subclass__(**kwargs)
@property
def io_type(self):
return self.Parent.io_type
class InputV3(IO_V3, io_type=None):
@property
def Type(self):
return self.Parent.Type
class InputV3(IO_V3):
'''
Base class for a V3 Input.
'''
@ -100,7 +146,7 @@ class InputV3(IO_V3, io_type=None):
def get_io_type_V1(self):
return self.io_type
class WidgetInputV3(InputV3, io_type=None):
class WidgetInputV3(InputV3):
'''
Base class for a V3 Input with widget.
'''
@ -119,7 +165,7 @@ class WidgetInputV3(InputV3, io_type=None):
"widgetType": self.widgetType,
})
class OutputV3(IO_V3, io_type=None):
class OutputV3(IO_V3):
def __init__(self, id: str, display_name: str=None, tooltip: str=None,
is_output_list=False):
self.id = id
@ -156,10 +202,11 @@ def CustomOutput(id: str, io_type: IO | str, display_name: str=None, tooltip: st
return type(f"{io_type}Output", (OutputV3,), {}, io_type=io_type)(**input_kwargs)
@comfytype(io_type=IO.BOOLEAN)
class Boolean:
Type = bool
class Input(WidgetInputV3, io_type=IO.BOOLEAN):
class Input(WidgetInputV3):
'''Boolean input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: bool=None, label_on: str=None, label_off: str=None,
@ -175,13 +222,14 @@ class Boolean:
"label_off": self.label_off,
})
class Output(OutputV3, io_type=IO.BOOLEAN):
class Output(OutputV3):
...
class Integer:
@comfytype(io_type=IO.INT)
class Int:
Type = int
class Input(WidgetInputV3, io_type=IO.INT):
class Input(WidgetInputV3):
'''Integer input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
@ -203,13 +251,14 @@ class Integer:
"display": self.display_mode,
})
class Output(OutputV3, io_type=IO.INT):
class Output(OutputV3):
...
@comfytype(io_type=IO.FLOAT)
class Float:
Type = float
class Input(WidgetInputV3, io_type=IO.FLOAT):
class Input(WidgetInputV3):
'''Float input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
@ -232,13 +281,14 @@ class Float:
"display": self.display_mode,
})
class Output(OutputV3, io_type=IO.FLOAT):
class Output(OutputV3):
...
@comfytype(io_type=IO.STRING)
class String:
Type = str
class Input(WidgetInputV3, io_type=IO.STRING):
class Input(WidgetInputV3):
'''String input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
multiline=False, placeholder: str=None, default: int=None,
@ -254,306 +304,292 @@ class String:
"placeholder": self.placeholder,
})
class Output(OutputV3, io_type=IO.STRING):
class Output(OutputV3):
...
class ComboInput(WidgetInputV3, io_type=IO.COMBO):
'''Combo input (dropdown).'''
@comfytype(io_type=IO.COMBO)
class Combo:
Type = str
def __init__(self, id: str, options: list[str]=None, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: str=None, control_after_generate: bool=None,
image_upload: bool=None, image_folder: FolderType=None,
remote: RemoteOptions=None,
socketless: bool=None, widgetType: str=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, widgetType)
self.multiselect = False
self.options = options
self.control_after_generate = control_after_generate
self.image_upload = image_upload
self.image_folder = image_folder
self.remote = remote
self.default: str
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"multiselect": self.multiselect,
"options": self.options,
"control_after_generate": self.control_after_generate,
"image_upload": self.image_upload,
"image_folder": self.image_folder.value if self.image_folder else None,
"remote": self.remote.as_dict() if self.remote else None,
})
class Input(WidgetInputV3):
'''Combo input (dropdown).'''
Type = str
def __init__(self, id: str, options: list[str]=None, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: str=None, control_after_generate: bool=None,
image_upload: bool=None, image_folder: FolderType=None,
remote: RemoteOptions=None,
socketless: bool=None, widgetType: str=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, widgetType)
self.multiselect = False
self.options = options
self.control_after_generate = control_after_generate
self.image_upload = image_upload
self.image_folder = image_folder
self.remote = remote
self.default: str
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"multiselect": self.multiselect,
"options": self.options,
"control_after_generate": self.control_after_generate,
"image_upload": self.image_upload,
"image_folder": self.image_folder.value if self.image_folder else None,
"remote": self.remote.as_dict() if self.remote else None,
})
class MultiselectComboWidget(ComboInput, io_type=IO.COMBO):
@comfytype(io_type=IO.COMBO)
class MultiCombo:
'''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
socketless: bool=None, widgetType: str=None):
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless, widgetType)
self.multiselect = True
self.placeholder = placeholder
self.chip = chip
self.default: list[str]
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"multiselect": self.multiselect,
"placeholder": self.placeholder,
"chip": self.chip,
})
Type = list[str]
class Input(Combo.Input):
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
socketless: bool=None, widgetType: str=None):
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless, widgetType)
self.multiselect = True
self.placeholder = placeholder
self.chip = chip
self.default: list[str]
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
"multiselect": self.multiselect,
"placeholder": self.placeholder,
"chip": self.chip,
})
@comfytype(io_type=IO.IMAGE)
class Image:
Type = torch.Tensor
class Input(InputV3, io_type=IO.IMAGE):
class Input(InputV3):
'''Image input.'''
class Output(OutputV3, io_type=IO.IMAGE):
class Output(OutputV3):
...
@comfytype(io_type=IO.MASK)
class Mask:
Type = torch.Tensor
class Input(InputV3, io_type=IO.MASK):
class Input(InputV3):
'''Mask input.'''
class Output(OutputV3, io_type=IO.MASK):
class Output(OutputV3):
...
@comfytype(io_type=IO.LATENT)
class Latent:
Type = Any # TODO: make Type a TypedDict
class Input(InputV3, io_type=IO.LATENT):
class Input(InputV3):
'''Latent input.'''
class Output(OutputV3, io_type=IO.LATENT):
class Output(OutputV3):
...
@comfytype(io_type=IO.CONDITIONING)
class Conditioning:
Type = Any
class Input(InputV3, io_type=IO.CONDITIONING):
class Input(InputV3):
'''Conditioning input.'''
class Output(OutputV3, io_type=IO.CONDITIONING):
class Output(OutputV3):
...
@comfytype(io_type=IO.SAMPLER)
class Sampler:
Type = Any
class Input(InputV3, io_type=IO.SAMPLER):
class Input(InputV3):
'''Sampler input.'''
class Output(OutputV3, io_type=IO.SAMPLER):
class Output(OutputV3):
...
@comfytype(io_type=IO.SIGMAS)
class Sigmas:
Type = Any
class Input(InputV3, io_type=IO.SIGMAS):
class Input(InputV3):
'''Sigmas input.'''
class Output(OutputV3, io_type=IO.SIGMAS):
class Output(OutputV3):
...
@comfytype(io_type=IO.NOISE)
class Noise:
Type = Any
class Input(InputV3, io_type=IO.NOISE):
class Input(InputV3):
'''Noise input.'''
class Output(OutputV3, io_type=IO.NOISE):
class Output(OutputV3):
...
@comfytype(io_type=IO.GUIDER)
class Guider:
Type = Any
class Input(InputV3, io_type=IO.GUIDER):
class Input(InputV3):
'''Guider input.'''
class Output(OutputV3, io_type=IO.GUIDER):
class Output(OutputV3):
...
@comfytype(io_type=IO.CLIP)
class Clip:
Type = Any
class Input(InputV3, io_type=IO.CLIP):
class Input(InputV3):
'''Clip input.'''
class Output(OutputV3, io_type=IO.CLIP):
class Output(OutputV3):
...
@comfytype(io_type=IO.CONTROL_NET)
class ControlNet:
Type = Any
class Input(InputV3, io_type=IO.CONTROL_NET):
class Input(InputV3):
'''ControlNet input.'''
class Output(OutputV3, io_type=IO.CONTROL_NET):
class Output(OutputV3):
...
@comfytype(io_type=IO.VAE)
class Vae:
Type = Any
class Input(InputV3, io_type=IO.VAE):
class Input(InputV3):
'''Vae input.'''
class Output(OutputV3, io_type=IO.VAE):
class Output(OutputV3):
...
@comfytype(io_type=IO.MODEL)
class Model:
Type = Any
class Input(InputV3, io_type=IO.MODEL):
class Input(InputV3):
'''Model input.'''
class Output(OutputV3, io_type=IO.MODEL):
class Output(OutputV3):
...
@comfytype(io_type=IO.CLIP_VISION)
class ClipVision:
Type = Any
class Input(InputV3, io_type=IO.CLIP_VISION):
class Input(InputV3):
'''ClipVision input.'''
class Output(OutputV3, io_type=IO.CLIP_VISION):
class Output(OutputV3):
...
@comfytype(io_type=IO.CLIP_VISION_OUTPUT)
class ClipVisionOutput:
Type = Any
class Input(InputV3, io_type=IO.CLIP_VISION_OUTPUT):
'''CLipVisionOutput input.'''
class Output(OutputV3, io_type=IO.CLIP_VISION_OUTPUT):
class Input(InputV3):
'''ClipVisionOutput input.'''
class Output(OutputV3):
...
@comfytype(io_type=IO.STYLE_MODEL)
class StyleModel:
Type = Any
class Input(InputV3, io_type=IO.STYLE_MODEL):
class Input(InputV3):
'''StyleModel input.'''
class Output(OutputV3, io_type=IO.STYLE_MODEL):
class Output(OutputV3):
...
@comfytype(io_type=IO.GLIGEN)
class Gligen:
Type = Any
class Input(InputV3, io_type=IO.GLIGEN):
class Input(InputV3):
'''Gligen input.'''
class Output(OutputV3, io_type=IO.GLIGEN):
class Output(OutputV3):
...
@comfytype(io_type=IO.UPSCALE_MODEL)
class UpscaleModel:
Type = Any
class Input(InputV3, io_type=IO.UPSCALE_MODEL):
class Input(InputV3):
'''UpscaleModel input.'''
class Output(OutputV3, io_type=IO.UPSCALE_MODEL):
class Output(OutputV3):
...
@comfytype(io_type=IO.AUDIO)
class Audio:
Type = Any
class Input(InputV3, io_type=IO.AUDIO):
class Input(InputV3):
'''Audio input.'''
class Output(OutputV3, io_type=IO.AUDIO):
class Output(OutputV3):
...
@comfytype(io_type=IO.POINT)
class Point:
Type = Any
class Input(InputV3, io_type=IO.POINT):
class Input(InputV3):
'''Point input.'''
class Output(OutputV3, io_type=IO.POINT):
class Output(OutputV3):
...
@comfytype(io_type=IO.FACE_ANALYSIS)
class FaceAnalysis:
Type = Any
class Input(InputV3, io_type=IO.FACE_ANALYSIS):
class Input(InputV3):
'''FaceAnalysis input.'''
class Output(OutputV3, io_type=IO.FACE_ANALYSIS):
class Output(OutputV3):
...
@comfytype(io_type=IO.BBOX)
class BBOX:
Type = Any
class Input(InputV3, io_type=IO.BBOX):
class Input(InputV3):
'''Bbox input.'''
class Output(OutputV3, io_type=IO.BBOX):
class Output(OutputV3):
...
@comfytype(io_type=IO.SEGS)
class SEGS:
Type = Any
class Input(InputV3, io_type=IO.SEGS):
class Input(InputV3):
'''SEGS input.'''
class Output(OutputV3, io_type=IO.SEGS):
class Output(OutputV3):
...
@comfytype(io_type=IO.VIDEO)
class Video:
Type = Any
class Input(InputV3, io_type=IO.VIDEO):
class Input(InputV3):
'''Video input.'''
class Output(OutputV3, io_type=IO.VIDEO):
class Output(OutputV3):
...
class MultitypedInput(InputV3, io_type="COMFY_MULTITYPED_V3"):
'''
Input that permits more than one input type.
'''
def __init__(self, id: str, io_types: list[type[IO_V3] | InputV3 | IO |str], display_name: str=None, optional=False, tooltip: str=None,):
super().__init__(id, display_name, optional, tooltip)
self._io_types = io_types
@property
def io_types(self) -> list[type[InputV3]]:
@comfytype(io_type="COMFY_MULTITYPED_V3")
class MultiType:
Type = Any
class Input(InputV3):
'''
Returns list of InputV3 class types permitted.
Input that permits more than one input type.
'''
io_types = []
for x in self._io_types:
if not is_class(x):
io_types.append(type(x))
else:
io_types.append(x)
return io_types
def get_io_type_V1(self):
return ",".join(x.io_type for x in self.io_types)
def __init__(self, id: str, io_types: list[type[ComfyType] | ComfyType | IO |str], display_name: str=None, optional=False, tooltip: str=None,):
super().__init__(id, display_name, optional, tooltip)
self._io_types = io_types
@property
def io_types(self) -> list[type[InputV3]]:
'''
Returns list of InputV3 class types permitted.
'''
io_types = []
for x in self._io_types:
if not is_class(x):
io_types.append(type(x))
else:
io_types.append(x)
return io_types
def get_io_type_V1(self):
return ",".join(x.io_type for x in self.io_types)
class DynamicInput(InputV3, io_type=None):
class DynamicInput(InputV3):
'''
Abstract class for dynamic input registration.
'''
def __init__(self, io_type: str, id: str, display_name: str=None):
super().__init__(io_type, id, display_name)
class DynamicOutput(OutputV3, io_type=None):
class DynamicOutput(OutputV3):
'''
Abstract class for dynamic output registration.
'''
def __init__(self, io_type: str, id: str, display_name: str=None):
super().__init__(io_type, id, display_name)
class AutoGrowDynamicInput(DynamicInput, io_type="COMFY_MULTIGROW_V3"):
# io_type="COMFY_MULTIGROW_V3"
class AutoGrowDynamicInput(DynamicInput):
'''
Dynamic Input that adds another template_input each time one is provided.
Additional inputs are forced to have 'InputBehavior.optional'.
Additional inputs are forced to have 'optional=True'.
'''
def __init__(self, id: str, template_input: InputV3, min: int=1, max: int=None):
super().__init__("AutoGrowDynamicInput", id)
@ -565,7 +601,8 @@ class AutoGrowDynamicInput(DynamicInput, io_type="COMFY_MULTIGROW_V3"):
self.min = min
self.max = max
class ComboDynamicInput(DynamicInput, io_type="COMFY_COMBODYNAMIC_V3"):
# io_type="COMFY_COMBODYNAMIC_V3"
class ComboDynamicInput(DynamicInput):
def __init__(self, id: str):
pass
@ -576,7 +613,6 @@ class Hidden(str, Enum):
'''
Enumerator for requesting hidden variables in nodes.
'''
unique_id = "UNIQUE_ID"
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
prompt = "PROMPT"
@ -979,14 +1015,6 @@ class ComfyNodeV3(BASE_CV3):
#--------------------------------------------
#############################################
# class ReturnedInputs:
# def __init__(self):
# pass
# class ReturnedOutputs:
# def __init__(self):
# pass
class NodeOutput(BASE_NO):
'''
@ -1074,6 +1102,11 @@ class UIText(UIOutput):
return {"text": (self.value,)}
def create_image_preview(image: Image.Type) -> UIImages:
# TODO: finish, right now is just Cursor's hallucination
return UIImages([SavedResult("preview.png", "comfy_org", FolderType.output)])
class TestNode(ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
@ -1081,7 +1114,7 @@ class TestNode(ComfyNodeV3):
node_id="TestNode_v3",
display_name="Test Node (V3)",
category="v3_test",
inputs=[Integer.Input("my_int"),
inputs=[Int.Input("my_int"),
#AutoGrowDynamicInput("growing", Image.Input),
Mask.Input("thing"),
],
@ -1096,8 +1129,8 @@ class TestNode(ComfyNodeV3):
if __name__ == "__main__":
print("hello there")
inputs: list[InputV3] = [
Integer.Input("tessfes", widgetType=IO.STRING),
Integer.Input("my_int"),
Int.Input("tessfes", widgetType=IO.STRING),
Int.Input("my_int"),
CustomInput("xyz", "XYZ"),
CustomInput("model1", "MODEL_M"),
Image.Input("my_image"),

View File

@ -9,13 +9,13 @@ class TestNode(ComfyNodeABC):
return {
"required": {
"image": (IO.IMAGE,),
"xyz": ("XYZ",),
"some_int": (IO.INT, {"display_name": "new_name",
"min": 0, "max": 127, "default": 42,
"tooltip": "My tooltip 😎", "display": "slider"}),
"combo": (IO.COMBO, {"options": ["a", "b", "c"], "tooltip": "This is a combo input"}),
},
"optional": {
"xyz": ("XYZ",),
"mask": (IO.MASK,),
}
}
@ -29,7 +29,7 @@ class TestNode(ComfyNodeABC):
CATEGORY = "v3 nodes"
def do_thing(self, image: torch.Tensor, xyz, some_int: int, combo: str, mask: torch.Tensor=None):
def do_thing(self, image: torch.Tensor, some_int: int, combo: str, xyz=None, mask: torch.Tensor=None):
return (some_int, image)

View File

@ -3,14 +3,14 @@ from comfy_api.v3_01 import io
import logging
@io.comfytype(io_type="XYZ")
class XYZ:
Type = tuple[int,str]
class Input(io.InputV3, io_type="XYZ"):
class Input(io.InputV3):
...
class Output(io.OutputV3, io_type="XYZ"):
class Output(io.OutputV3):
...
class V3TestNode(io.ComfyNodeV3):
def __init__(self):
@ -28,9 +28,10 @@ class V3TestNode(io.ComfyNodeV3):
XYZ.Input("xyz", optional=True),
#CustomInput("xyz", "XYZ", optional=True),
io.Mask.Input("mask", optional=True),
io.Integer.Input("some_int", display_name="new_name", min=0, max=127, default=42,
io.Int.Input("some_int", display_name="new_name", min=0, max=127, default=42,
tooltip="My tooltip 😎", display_mode=io.NumberDisplay.slider),
io.ComboInput("combo", options=["a", "b", "c"], tooltip="This is a combo input"),
io.Combo.Input("combo", options=["a", "b", "c"], tooltip="This is a combo input"),
io.MultiCombo.Input("combo2", options=["a","b","c"]),
# ComboInput("combo", image_upload=True, image_folder=FolderType.output,
# remote=RemoteOptions(
# route="/internal/files/output",
@ -49,7 +50,7 @@ class V3TestNode(io.ComfyNodeV3):
# ]]
],
outputs=[
io.Integer.Output("int_output"),
io.Int.Output("int_output"),
io.Image.Output("img_output", display_name="img🖼", tooltip="This is an image"),
],
hidden=[
@ -59,8 +60,8 @@ class V3TestNode(io.ComfyNodeV3):
)
@classmethod
def execute(cls, image: io.Image.Type, some_int: int, combo: io.ComboInput.Type, xyz: XYZ.Type=None, mask: io.Mask.Type=None):
some_int
def execute(cls, image: io.Image.Type, some_int: int, combo: io.Combo.Type, combo2: io.MultiCombo.Type, xyz: XYZ.Type=None, mask: io.Mask.Type=None):
#some_int
if hasattr(cls, "hahajkunless"):
raise Exception("The 'cls' variable leaked instance state between runs!")
if hasattr(cls, "doohickey"):