Try out adding Type class var to IO_V3 to help with type hints

This commit is contained in:
kosinkadink1@gmail.com
2025-06-10 00:19:17 -07:00
parent 2197b6cbf3
commit 70d2bbfec0
2 changed files with 26 additions and 5 deletions

View File

@@ -1,10 +1,13 @@
from __future__ import annotations
from typing import Any, Literal
from typing import Any, Literal, TYPE_CHECKING, TypeVar
from enum import Enum
from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict
from comfy.comfy_types.node_typing import IO
# if TYPE_CHECKING:
import torch
class InputBehavior(str, Enum):
required = "required"
@@ -60,11 +63,14 @@ class IO_V3:
'''
Base class for V3 Inputs and Outputs.
'''
Type = Any
def __init__(self):
pass
def __init_subclass__(cls, io_type: IO | str, **kwargs):
def __init_subclass__(cls, io_type: IO | str, Type=Any, **kwargs):
cls.io_type = io_type
cls.Type = Type
super().__init_subclass__(**kwargs)
class InputV3(IO_V3, io_type=None):
@@ -141,6 +147,7 @@ class BooleanInput(WidgetInputV3, io_type=IO.BOOLEAN):
'''
Boolean input.
'''
Type = bool
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: bool=None, label_on: str=None, label_off: str=None,
socketless: bool=None, widgetType: str=None):
@@ -159,6 +166,7 @@ class IntegerInput(WidgetInputV3, io_type=IO.INT):
'''
Integer input.
'''
Type = int
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
display_mode: NumberDisplay=None, socketless: bool=None, widgetType: str=None):
@@ -183,6 +191,7 @@ class FloatInput(WidgetInputV3, io_type=IO.FLOAT):
'''
Float input.
'''
Type = float
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
display_mode: NumberDisplay=None, socketless: bool=None, widgetType: str=None):
@@ -208,6 +217,7 @@ class StringInput(WidgetInputV3, io_type=IO.STRING):
'''
String input.
'''
Type = str
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
multiline=False, placeholder: str=None, default: int=None,
socketless: bool=None, widgetType: str=None):
@@ -224,6 +234,7 @@ class StringInput(WidgetInputV3, io_type=IO.STRING):
class ComboInput(WidgetInputV3, io_type=IO.COMBO):
'''Combo input (dropdown).'''
Type = str
def __init__(self, id: str, options: list[str]=None, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None, lazy: bool=None,
default: str=None, control_after_generate: bool=None,
image_upload: bool=None, image_folder: FolderType=None,
@@ -270,6 +281,7 @@ class ImageInput(InputV3, io_type=IO.IMAGE):
'''
Image input.
'''
Type = torch.Tensor
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
super().__init__(id, display_name, behavior, tooltip)
@@ -277,6 +289,7 @@ class MaskInput(InputV3, io_type=IO.MASK):
'''
Mask input.
'''
Type = torch.Tensor
def __init__(self, id: str, display_name: str=None, behavior=InputBehavior.required, tooltip: str=None):
super().__init__(id, display_name, behavior, tooltip)