diff --git a/comfy_api/v3/io.py b/comfy_api/v3/io.py index b8867f3d0..825d59538 100644 --- a/comfy_api/v3/io.py +++ b/comfy_api/v3/io.py @@ -1,10 +1,13 @@ from __future__ import annotations -from typing import Any, Literal +from typing import Any, Literal, TYPE_CHECKING, TypeVar from enum import Enum from abc import ABC, abstractmethod from dataclasses import dataclass, asdict from comfy.comfy_types.node_typing import IO +# if TYPE_CHECKING: +import torch + class InputBehavior(str, Enum): required = "required" @@ -60,11 +63,14 @@ class IO_V3: ''' Base class for V3 Inputs and Outputs. ''' + Type = Any + def __init__(self): pass - def __init_subclass__(cls, io_type: IO | str, **kwargs): + def __init_subclass__(cls, io_type: IO | str, Type=Any, **kwargs): cls.io_type = io_type + cls.Type = Type super().__init_subclass__(**kwargs) class InputV3(IO_V3, io_type=None): @@ -141,6 +147,7 @@ class BooleanInput(WidgetInputV3, io_type=IO.BOOLEAN): ''' Boolean input. ''' + Type = bool def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None, default: bool=None, label_on: str=None, label_off: str=None, socketless: bool=None, widgetType: str=None): @@ -159,6 +166,7 @@ class IntegerInput(WidgetInputV3, io_type=IO.INT): ''' Integer input. ''' + Type = int def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None, default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, display_mode: NumberDisplay=None, socketless: bool=None, widgetType: str=None): @@ -183,6 +191,7 @@ class FloatInput(WidgetInputV3, io_type=IO.FLOAT): ''' Float input. ''' + Type = float def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, display_mode: NumberDisplay=None, socketless: bool=None, widgetType: str=None): @@ -208,6 +217,7 @@ class StringInput(WidgetInputV3, io_type=IO.STRING): ''' String input. ''' + Type = str def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None, multiline=False, placeholder: str=None, default: int=None, socketless: bool=None, widgetType: str=None): @@ -224,6 +234,7 @@ class StringInput(WidgetInputV3, io_type=IO.STRING): class ComboInput(WidgetInputV3, io_type=IO.COMBO): '''Combo input (dropdown).''' + Type = str def __init__(self, id: str, options: list[str]=None, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None, default: str=None, control_after_generate: bool=None, image_upload: bool=None, image_folder: FolderType=None, @@ -270,6 +281,7 @@ class ImageInput(InputV3, io_type=IO.IMAGE): ''' Image input. ''' + Type = torch.Tensor def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None): super().__init__(id, display_name, behavior, tooltip) @@ -277,6 +289,7 @@ class MaskInput(InputV3, io_type=IO.MASK): ''' Mask input. ''' + Type = torch.Tensor def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None): super().__init__(id, display_name, behavior, tooltip) diff --git a/comfy_extras/nodes_v3_test.py b/comfy_extras/nodes_v3_test.py index 5e2109355..49bef58d4 100644 --- a/comfy_extras/nodes_v3_test.py +++ b/comfy_extras/nodes_v3_test.py @@ -2,12 +2,19 @@ import torch from comfy_api.v3.io import ( ComfyNodeV3, SchemaV3, InputBehavior, NumberDisplay, IntegerInput, MaskInput, ImageInput, ComboInput, CustomInput, StringInput, CustomType, - IntegerOutput, ImageOutput, MultitypedInput, + IntegerOutput, ImageOutput, MultitypedInput, InputV3, OutputV3, NodeOutput, Hidden ) import logging +class XYZInput(InputV3, io_type="XYZ"): + Type = tuple[int,str] + +class XYZOutput(OutputV3, io_type="XYZ"): + ... + + class V3TestNode(ComfyNodeV3): def __init__(self): @@ -22,7 +29,8 @@ class V3TestNode(ComfyNodeV3): category="v3 nodes", inputs=[ ImageInput("image", display_name="new_image"), - CustomInput("xyz", "XYZ", behavior=InputBehavior.optional), + XYZInput("xyz", behavior=InputBehavior.optional), + #CustomInput("xyz", "XYZ", behavior=InputBehavior.optional), MaskInput("mask", behavior=InputBehavior.optional), IntegerInput("some_int", display_name="new_name", min=0, max=127, default=42, tooltip="My tooltip 😎", display_mode=NumberDisplay.slider), @@ -55,7 +63,7 @@ class V3TestNode(ComfyNodeV3): ) @classmethod - def execute(cls, image: torch.Tensor, some_int: int, combo: str, xyz=None, mask: torch.Tensor=None): + def execute(cls, image: ImageInput.Type, some_int: IntegerInput.Type, combo: ComboInput.Type, xyz: XYZInput.Type=None, mask: MaskInput.Type=None): if hasattr(cls, "hahajkunless"): raise Exception("The 'cls' variable leaked instance state between runs!") if hasattr(cls, "doohickey"):