diff --git a/comfy_api/v3_01/io.py b/comfy_api/v3_01/io.py index 86a97c8e9..acc62a30d 100644 --- a/comfy_api/v3_01/io.py +++ b/comfy_api/v3_01/io.py @@ -1,22 +1,16 @@ from __future__ import annotations -from typing import Any, Literal, TYPE_CHECKING, TypeVar +from typing import Any, Literal, TYPE_CHECKING, TypeVar, Callable, Optional, cast from enum import Enum from abc import ABC, abstractmethod from dataclasses import dataclass, asdict from comfy.comfy_types.node_typing import IO +# NOTE: these imports here are mostly for keeping execution.py happy with type inheritance from comfy_api.v3.io import ComfyNodeV3 as BASE_CV3 from comfy_api.v3.io import NodeOutput as BASE_NO -# if TYPE_CHECKING: import torch -class InputBehavior(str, Enum): - '''Likely deprecated; required/optional can be a bool, unlikely to be more categories that fit.''' - required = "required" - optional = "optional" - - class FolderType(str, Enum): input = "input" output = "output" @@ -62,23 +56,75 @@ class NumberDisplay(str, Enum): slider = "slider" +class ComfyType: + Type = Any + io_type: str = None + Input = None + Output = None + +# NOTE: this is a workaround to make the decorator return the correct type +T = TypeVar("T", bound=type) +def comfytype(io_type: str, **kwargs): + ''' + Decorator to mark nested classes as ComfyType; io_type will be bound to the class. + + A ComfyType may have the following attributes: + - Type = + - class Input(InputV3): ... + - class Output(OutputV3): ... + ''' + def decorator(cls: T) -> T: + if not isinstance(cls, ComfyType): + # copy class attributes except for special ones that shouldn't be in type() + cls_dict = { + k: v for k, v in cls.__dict__.items() + if k not in ('__dict__', '__weakref__', '__module__', '__doc__') + } + # new class + new_cls: ComfyType = type( + cls.__name__, + (cls, ComfyType), + cls_dict + ) + # metadata preservation + new_cls.__module__ = cls.__module__ + new_cls.__doc__ = cls.__doc__ + # assign ComfyType attributes, if needed + # NOTE: do we need __ne__ trick for io_type? (see IO.__ne__ for details) + else: + new_cls = cls + new_cls.io_type = io_type + if new_cls.Input is not None: + new_cls.Input.Parent = new_cls + if new_cls.Output is not None: + new_cls.Output.Parent = new_cls + return new_cls + return decorator class IO_V3: ''' Base class for V3 Inputs and Outputs. ''' - Type = Any + Parent: ComfyType = None def __init__(self): pass - def __init_subclass__(cls, io_type: IO | str, **kwargs): - # TODO: do we need __ne__ trick for io_type? (see IO.__ne__ for details) - cls.io_type = io_type - super().__init_subclass__(**kwargs) + # def __init_subclass__(cls, io_type: IO | str, **kwargs): + # # TODO: do we need __ne__ trick for io_type? (see IO.__ne__ for details) + # cls.io_type = io_type + # super().__init_subclass__(**kwargs) + + @property + def io_type(self): + return self.Parent.io_type -class InputV3(IO_V3, io_type=None): + @property + def Type(self): + return self.Parent.Type + +class InputV3(IO_V3): ''' Base class for a V3 Input. ''' @@ -100,7 +146,7 @@ class InputV3(IO_V3, io_type=None): def get_io_type_V1(self): return self.io_type -class WidgetInputV3(InputV3, io_type=None): +class WidgetInputV3(InputV3): ''' Base class for a V3 Input with widget. ''' @@ -119,7 +165,7 @@ class WidgetInputV3(InputV3, io_type=None): "widgetType": self.widgetType, }) -class OutputV3(IO_V3, io_type=None): +class OutputV3(IO_V3): def __init__(self, id: str, display_name: str=None, tooltip: str=None, is_output_list=False): self.id = id @@ -156,10 +202,11 @@ def CustomOutput(id: str, io_type: IO | str, display_name: str=None, tooltip: st return type(f"{io_type}Output", (OutputV3,), {}, io_type=io_type)(**input_kwargs) +@comfytype(io_type=IO.BOOLEAN) class Boolean: Type = bool - class Input(WidgetInputV3, io_type=IO.BOOLEAN): + class Input(WidgetInputV3): '''Boolean input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: bool=None, label_on: str=None, label_off: str=None, @@ -175,13 +222,14 @@ class Boolean: "label_off": self.label_off, }) - class Output(OutputV3, io_type=IO.BOOLEAN): + class Output(OutputV3): ... -class Integer: +@comfytype(io_type=IO.INT) +class Int: Type = int - class Input(WidgetInputV3, io_type=IO.INT): + class Input(WidgetInputV3): '''Integer input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, @@ -203,13 +251,14 @@ class Integer: "display": self.display_mode, }) - class Output(OutputV3, io_type=IO.INT): + class Output(OutputV3): ... +@comfytype(io_type=IO.FLOAT) class Float: Type = float - class Input(WidgetInputV3, io_type=IO.FLOAT): + class Input(WidgetInputV3): '''Float input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, @@ -232,13 +281,14 @@ class Float: "display": self.display_mode, }) - class Output(OutputV3, io_type=IO.FLOAT): + class Output(OutputV3): ... +@comfytype(io_type=IO.STRING) class String: Type = str - class Input(WidgetInputV3, io_type=IO.STRING): + class Input(WidgetInputV3): '''String input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, multiline=False, placeholder: str=None, default: int=None, @@ -254,306 +304,292 @@ class String: "placeholder": self.placeholder, }) - class Output(OutputV3, io_type=IO.STRING): + class Output(OutputV3): ... -class ComboInput(WidgetInputV3, io_type=IO.COMBO): - '''Combo input (dropdown).''' +@comfytype(io_type=IO.COMBO) +class Combo: Type = str - def __init__(self, id: str, options: list[str]=None, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, - default: str=None, control_after_generate: bool=None, - image_upload: bool=None, image_folder: FolderType=None, - remote: RemoteOptions=None, - socketless: bool=None, widgetType: str=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, widgetType) - self.multiselect = False - self.options = options - self.control_after_generate = control_after_generate - self.image_upload = image_upload - self.image_folder = image_folder - self.remote = remote - self.default: str - - def as_dict_V1(self): - return super().as_dict_V1() | prune_dict({ - "multiselect": self.multiselect, - "options": self.options, - "control_after_generate": self.control_after_generate, - "image_upload": self.image_upload, - "image_folder": self.image_folder.value if self.image_folder else None, - "remote": self.remote.as_dict() if self.remote else None, - }) + class Input(WidgetInputV3): + '''Combo input (dropdown).''' + Type = str + def __init__(self, id: str, options: list[str]=None, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + default: str=None, control_after_generate: bool=None, + image_upload: bool=None, image_folder: FolderType=None, + remote: RemoteOptions=None, + socketless: bool=None, widgetType: str=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, widgetType) + self.multiselect = False + self.options = options + self.control_after_generate = control_after_generate + self.image_upload = image_upload + self.image_folder = image_folder + self.remote = remote + self.default: str + + def as_dict_V1(self): + return super().as_dict_V1() | prune_dict({ + "multiselect": self.multiselect, + "options": self.options, + "control_after_generate": self.control_after_generate, + "image_upload": self.image_upload, + "image_folder": self.image_folder.value if self.image_folder else None, + "remote": self.remote.as_dict() if self.remote else None, + }) -class MultiselectComboWidget(ComboInput, io_type=IO.COMBO): +@comfytype(io_type=IO.COMBO) +class MultiCombo: '''Multiselect Combo input (dropdown for selecting potentially more than one value).''' - def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, - default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None, - socketless: bool=None, widgetType: str=None): - super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless, widgetType) - self.multiselect = True - self.placeholder = placeholder - self.chip = chip - self.default: list[str] - - def as_dict_V1(self): - return super().as_dict_V1() | prune_dict({ - "multiselect": self.multiselect, - "placeholder": self.placeholder, - "chip": self.chip, - }) + Type = list[str] + class Input(Combo.Input): + def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None, + socketless: bool=None, widgetType: str=None): + super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless, widgetType) + self.multiselect = True + self.placeholder = placeholder + self.chip = chip + self.default: list[str] + + def as_dict_V1(self): + return super().as_dict_V1() | prune_dict({ + "multiselect": self.multiselect, + "placeholder": self.placeholder, + "chip": self.chip, + }) + +@comfytype(io_type=IO.IMAGE) class Image: Type = torch.Tensor - - class Input(InputV3, io_type=IO.IMAGE): + class Input(InputV3): '''Image input.''' - - class Output(OutputV3, io_type=IO.IMAGE): + class Output(OutputV3): ... +@comfytype(io_type=IO.MASK) class Mask: Type = torch.Tensor - - class Input(InputV3, io_type=IO.MASK): + class Input(InputV3): '''Mask input.''' - - class Output(OutputV3, io_type=IO.MASK): + class Output(OutputV3): ... +@comfytype(io_type=IO.LATENT) class Latent: Type = Any # TODO: make Type a TypedDict - - class Input(InputV3, io_type=IO.LATENT): + class Input(InputV3): '''Latent input.''' - - class Output(OutputV3, io_type=IO.LATENT): + class Output(OutputV3): ... +@comfytype(io_type=IO.CONDITIONING) class Conditioning: Type = Any - - class Input(InputV3, io_type=IO.CONDITIONING): + class Input(InputV3): '''Conditioning input.''' - - class Output(OutputV3, io_type=IO.CONDITIONING): + class Output(OutputV3): ... +@comfytype(io_type=IO.SAMPLER) class Sampler: Type = Any - - class Input(InputV3, io_type=IO.SAMPLER): + class Input(InputV3): '''Sampler input.''' - - class Output(OutputV3, io_type=IO.SAMPLER): + class Output(OutputV3): ... +@comfytype(io_type=IO.SIGMAS) class Sigmas: Type = Any - - class Input(InputV3, io_type=IO.SIGMAS): + class Input(InputV3): '''Sigmas input.''' - - class Output(OutputV3, io_type=IO.SIGMAS): + class Output(OutputV3): ... +@comfytype(io_type=IO.NOISE) class Noise: Type = Any - - class Input(InputV3, io_type=IO.NOISE): + class Input(InputV3): '''Noise input.''' - - class Output(OutputV3, io_type=IO.NOISE): + class Output(OutputV3): ... +@comfytype(io_type=IO.GUIDER) class Guider: Type = Any - - class Input(InputV3, io_type=IO.GUIDER): + class Input(InputV3): '''Guider input.''' - - class Output(OutputV3, io_type=IO.GUIDER): + class Output(OutputV3): ... +@comfytype(io_type=IO.CLIP) class Clip: Type = Any - - class Input(InputV3, io_type=IO.CLIP): + class Input(InputV3): '''Clip input.''' - - class Output(OutputV3, io_type=IO.CLIP): + class Output(OutputV3): ... +@comfytype(io_type=IO.CONTROL_NET) class ControlNet: Type = Any - - class Input(InputV3, io_type=IO.CONTROL_NET): + class Input(InputV3): '''ControlNet input.''' - - class Output(OutputV3, io_type=IO.CONTROL_NET): + class Output(OutputV3): ... +@comfytype(io_type=IO.VAE) class Vae: Type = Any - - class Input(InputV3, io_type=IO.VAE): + class Input(InputV3): '''Vae input.''' - - class Output(OutputV3, io_type=IO.VAE): + class Output(OutputV3): ... +@comfytype(io_type=IO.MODEL) class Model: Type = Any - - class Input(InputV3, io_type=IO.MODEL): + class Input(InputV3): '''Model input.''' - - class Output(OutputV3, io_type=IO.MODEL): + class Output(OutputV3): ... +@comfytype(io_type=IO.CLIP_VISION) class ClipVision: Type = Any - - class Input(InputV3, io_type=IO.CLIP_VISION): + class Input(InputV3): '''ClipVision input.''' - - class Output(OutputV3, io_type=IO.CLIP_VISION): + class Output(OutputV3): ... +@comfytype(io_type=IO.CLIP_VISION_OUTPUT) class ClipVisionOutput: Type = Any - - class Input(InputV3, io_type=IO.CLIP_VISION_OUTPUT): - '''CLipVisionOutput input.''' - - class Output(OutputV3, io_type=IO.CLIP_VISION_OUTPUT): + class Input(InputV3): + '''ClipVisionOutput input.''' + class Output(OutputV3): ... +@comfytype(io_type=IO.STYLE_MODEL) class StyleModel: Type = Any - - class Input(InputV3, io_type=IO.STYLE_MODEL): + class Input(InputV3): '''StyleModel input.''' - - class Output(OutputV3, io_type=IO.STYLE_MODEL): + class Output(OutputV3): ... +@comfytype(io_type=IO.GLIGEN) class Gligen: Type = Any - - class Input(InputV3, io_type=IO.GLIGEN): + class Input(InputV3): '''Gligen input.''' - - class Output(OutputV3, io_type=IO.GLIGEN): + class Output(OutputV3): ... +@comfytype(io_type=IO.UPSCALE_MODEL) class UpscaleModel: Type = Any - - class Input(InputV3, io_type=IO.UPSCALE_MODEL): + class Input(InputV3): '''UpscaleModel input.''' - - class Output(OutputV3, io_type=IO.UPSCALE_MODEL): + class Output(OutputV3): ... +@comfytype(io_type=IO.AUDIO) class Audio: Type = Any - - class Input(InputV3, io_type=IO.AUDIO): + class Input(InputV3): '''Audio input.''' - - class Output(OutputV3, io_type=IO.AUDIO): + class Output(OutputV3): ... +@comfytype(io_type=IO.POINT) class Point: Type = Any - - class Input(InputV3, io_type=IO.POINT): + class Input(InputV3): '''Point input.''' - - class Output(OutputV3, io_type=IO.POINT): + class Output(OutputV3): ... +@comfytype(io_type=IO.FACE_ANALYSIS) class FaceAnalysis: Type = Any - - class Input(InputV3, io_type=IO.FACE_ANALYSIS): + class Input(InputV3): '''FaceAnalysis input.''' - - class Output(OutputV3, io_type=IO.FACE_ANALYSIS): + class Output(OutputV3): ... +@comfytype(io_type=IO.BBOX) class BBOX: Type = Any - - class Input(InputV3, io_type=IO.BBOX): + class Input(InputV3): '''Bbox input.''' - - class Output(OutputV3, io_type=IO.BBOX): + class Output(OutputV3): ... +@comfytype(io_type=IO.SEGS) class SEGS: Type = Any - - class Input(InputV3, io_type=IO.SEGS): + class Input(InputV3): '''SEGS input.''' - - class Output(OutputV3, io_type=IO.SEGS): + class Output(OutputV3): ... +@comfytype(io_type=IO.VIDEO) class Video: Type = Any - - class Input(InputV3, io_type=IO.VIDEO): + class Input(InputV3): '''Video input.''' - - class Output(OutputV3, io_type=IO.VIDEO): + class Output(OutputV3): ... - -class MultitypedInput(InputV3, io_type="COMFY_MULTITYPED_V3"): - ''' - Input that permits more than one input type. - ''' - def __init__(self, id: str, io_types: list[type[IO_V3] | InputV3 | IO |str], display_name: str=None, optional=False, tooltip: str=None,): - super().__init__(id, display_name, optional, tooltip) - self._io_types = io_types - - @property - def io_types(self) -> list[type[InputV3]]: +@comfytype(io_type="COMFY_MULTITYPED_V3") +class MultiType: + Type = Any + class Input(InputV3): ''' - Returns list of InputV3 class types permitted. + Input that permits more than one input type. ''' - io_types = [] - for x in self._io_types: - if not is_class(x): - io_types.append(type(x)) - else: - io_types.append(x) - return io_types - - def get_io_type_V1(self): - return ",".join(x.io_type for x in self.io_types) + def __init__(self, id: str, io_types: list[type[ComfyType] | ComfyType | IO |str], display_name: str=None, optional=False, tooltip: str=None,): + super().__init__(id, display_name, optional, tooltip) + self._io_types = io_types + + @property + def io_types(self) -> list[type[InputV3]]: + ''' + Returns list of InputV3 class types permitted. + ''' + io_types = [] + for x in self._io_types: + if not is_class(x): + io_types.append(type(x)) + else: + io_types.append(x) + return io_types + + def get_io_type_V1(self): + return ",".join(x.io_type for x in self.io_types) - -class DynamicInput(InputV3, io_type=None): +class DynamicInput(InputV3): ''' Abstract class for dynamic input registration. ''' def __init__(self, io_type: str, id: str, display_name: str=None): super().__init__(io_type, id, display_name) -class DynamicOutput(OutputV3, io_type=None): +class DynamicOutput(OutputV3): ''' Abstract class for dynamic output registration. ''' def __init__(self, io_type: str, id: str, display_name: str=None): super().__init__(io_type, id, display_name) -class AutoGrowDynamicInput(DynamicInput, io_type="COMFY_MULTIGROW_V3"): +# io_type="COMFY_MULTIGROW_V3" +class AutoGrowDynamicInput(DynamicInput): ''' Dynamic Input that adds another template_input each time one is provided. - Additional inputs are forced to have 'InputBehavior.optional'. + Additional inputs are forced to have 'optional=True'. ''' def __init__(self, id: str, template_input: InputV3, min: int=1, max: int=None): super().__init__("AutoGrowDynamicInput", id) @@ -565,7 +601,8 @@ class AutoGrowDynamicInput(DynamicInput, io_type="COMFY_MULTIGROW_V3"): self.min = min self.max = max -class ComboDynamicInput(DynamicInput, io_type="COMFY_COMBODYNAMIC_V3"): +# io_type="COMFY_COMBODYNAMIC_V3" +class ComboDynamicInput(DynamicInput): def __init__(self, id: str): pass @@ -576,7 +613,6 @@ class Hidden(str, Enum): ''' Enumerator for requesting hidden variables in nodes. ''' - unique_id = "UNIQUE_ID" """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages).""" prompt = "PROMPT" @@ -979,14 +1015,6 @@ class ComfyNodeV3(BASE_CV3): #-------------------------------------------- ############################################# -# class ReturnedInputs: -# def __init__(self): -# pass - -# class ReturnedOutputs: -# def __init__(self): -# pass - class NodeOutput(BASE_NO): ''' @@ -1074,6 +1102,11 @@ class UIText(UIOutput): return {"text": (self.value,)} +def create_image_preview(image: Image.Type) -> UIImages: + # TODO: finish, right now is just Cursor's hallucination + return UIImages([SavedResult("preview.png", "comfy_org", FolderType.output)]) + + class TestNode(ComfyNodeV3): @classmethod def DEFINE_SCHEMA(cls): @@ -1081,7 +1114,7 @@ class TestNode(ComfyNodeV3): node_id="TestNode_v3", display_name="Test Node (V3)", category="v3_test", - inputs=[Integer.Input("my_int"), + inputs=[Int.Input("my_int"), #AutoGrowDynamicInput("growing", Image.Input), Mask.Input("thing"), ], @@ -1096,8 +1129,8 @@ class TestNode(ComfyNodeV3): if __name__ == "__main__": print("hello there") inputs: list[InputV3] = [ - Integer.Input("tessfes", widgetType=IO.STRING), - Integer.Input("my_int"), + Int.Input("tessfes", widgetType=IO.STRING), + Int.Input("my_int"), CustomInput("xyz", "XYZ"), CustomInput("model1", "MODEL_M"), Image.Input("my_image"), diff --git a/comfy_extras/nodes_v1_test.py b/comfy_extras/nodes_v1_test.py index 5ef31a3b7..ea9884856 100644 --- a/comfy_extras/nodes_v1_test.py +++ b/comfy_extras/nodes_v1_test.py @@ -9,13 +9,13 @@ class TestNode(ComfyNodeABC): return { "required": { "image": (IO.IMAGE,), - "xyz": ("XYZ",), "some_int": (IO.INT, {"display_name": "new_name", "min": 0, "max": 127, "default": 42, "tooltip": "My tooltip 😎", "display": "slider"}), "combo": (IO.COMBO, {"options": ["a", "b", "c"], "tooltip": "This is a combo input"}), }, "optional": { + "xyz": ("XYZ",), "mask": (IO.MASK,), } } @@ -29,7 +29,7 @@ class TestNode(ComfyNodeABC): CATEGORY = "v3 nodes" - def do_thing(self, image: torch.Tensor, xyz, some_int: int, combo: str, mask: torch.Tensor=None): + def do_thing(self, image: torch.Tensor, some_int: int, combo: str, xyz=None, mask: torch.Tensor=None): return (some_int, image) diff --git a/comfy_extras/nodes_v3_01_test.py b/comfy_extras/nodes_v3_01_test.py index 774ca5d26..13031392b 100644 --- a/comfy_extras/nodes_v3_01_test.py +++ b/comfy_extras/nodes_v3_01_test.py @@ -3,14 +3,14 @@ from comfy_api.v3_01 import io import logging +@io.comfytype(io_type="XYZ") class XYZ: Type = tuple[int,str] - class Input(io.InputV3, io_type="XYZ"): + class Input(io.InputV3): ... - class Output(io.OutputV3, io_type="XYZ"): + class Output(io.OutputV3): ... - class V3TestNode(io.ComfyNodeV3): def __init__(self): @@ -28,9 +28,10 @@ class V3TestNode(io.ComfyNodeV3): XYZ.Input("xyz", optional=True), #CustomInput("xyz", "XYZ", optional=True), io.Mask.Input("mask", optional=True), - io.Integer.Input("some_int", display_name="new_name", min=0, max=127, default=42, + io.Int.Input("some_int", display_name="new_name", min=0, max=127, default=42, tooltip="My tooltip 😎", display_mode=io.NumberDisplay.slider), - io.ComboInput("combo", options=["a", "b", "c"], tooltip="This is a combo input"), + io.Combo.Input("combo", options=["a", "b", "c"], tooltip="This is a combo input"), + io.MultiCombo.Input("combo2", options=["a","b","c"]), # ComboInput("combo", image_upload=True, image_folder=FolderType.output, # remote=RemoteOptions( # route="/internal/files/output", @@ -49,7 +50,7 @@ class V3TestNode(io.ComfyNodeV3): # ]] ], outputs=[ - io.Integer.Output("int_output"), + io.Int.Output("int_output"), io.Image.Output("img_output", display_name="img🖼️", tooltip="This is an image"), ], hidden=[ @@ -59,8 +60,8 @@ class V3TestNode(io.ComfyNodeV3): ) @classmethod - def execute(cls, image: io.Image.Type, some_int: int, combo: io.ComboInput.Type, xyz: XYZ.Type=None, mask: io.Mask.Type=None): - some_int + def execute(cls, image: io.Image.Type, some_int: int, combo: io.Combo.Type, combo2: io.MultiCombo.Type, xyz: XYZ.Type=None, mask: io.Mask.Type=None): + #some_int if hasattr(cls, "hahajkunless"): raise Exception("The 'cls' variable leaked instance state between runs!") if hasattr(cls, "doohickey"):