From aefd845a2124bd6afdbd524ef94991a01e93ff44 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 26 Jun 2025 15:41:49 -0700 Subject: [PATCH] Multitype refactor progress --- comfy_api/v3/io.py | 58 ++++++++++++++++++++++------------- comfy_extras/nodes_v3_test.py | 5 +-- 2 files changed, 39 insertions(+), 24 deletions(-) diff --git a/comfy_api/v3/io.py b/comfy_api/v3/io.py index bbdd12300..b0a4931c7 100644 --- a/comfy_api/v3/io.py +++ b/comfy_api/v3/io.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Any, Literal, TYPE_CHECKING, TypeVar, Callable, Optional, cast +from typing import Any, Literal, TYPE_CHECKING, TypeVar, Callable, Optional, cast, override from enum import Enum from abc import ABC, abstractmethod from dataclasses import dataclass, asdict @@ -178,12 +178,11 @@ class WidgetInputV3(InputV3): ''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: Any=None, - socketless: bool=None, widgetType: str=None, types: list[type[ComfyType] | ComfyType]=None, extra_dict=None): + socketless: bool=None, widgetType: str=None, extra_dict=None): super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) self.default = default self.socketless = socketless self.widgetType = widgetType - self.types = types if types is not None else [] def as_dict_V1(self): return super().as_dict_V1() | prune_dict({ @@ -193,11 +192,7 @@ class WidgetInputV3(InputV3): }) def get_io_type_V1(self): - # combine passed-in types and expected widgetType - str_types = [x.io_type for x in self.types] - str_types.insert(0, self.widgetType) - # ensure types are unique and order is preserved - return ','.join(list(dict.fromkeys(str_types))) + return self.widgetType if self.widgetType is not None else super().get_io_type_V1() class OutputV3(IO_V3): def __init__(self, id: str, display_name: str=None, tooltip: str=None, @@ -284,8 +279,8 @@ class Boolean: '''Boolean input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: bool=None, label_on: str=None, label_off: str=None, - socketless: bool=None, types: list[type[ComfyType] | ComfyType]=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type, types) + socketless: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type) self.label_on = label_on self.label_off = label_off self.default: bool @@ -307,8 +302,8 @@ class Int: '''Integer input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, - display_mode: NumberDisplay=None, socketless: bool=None, types: list[type[ComfyType] | ComfyType]=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type, types) + display_mode: NumberDisplay=None, socketless: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type) self.min = min self.max = max self.step = step @@ -336,8 +331,8 @@ class Float: '''Float input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, - display_mode: NumberDisplay=None, socketless: bool=None, types: list[type[ComfyType] | ComfyType]=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type, types) + display_mode: NumberDisplay=None, socketless: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type) self.min = min self.max = max self.step = step @@ -365,8 +360,8 @@ class String: '''String input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, multiline=False, placeholder: str=None, default: int=None, - socketless: bool=None, types: list[type[ComfyType] | ComfyType]=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type, types) + socketless: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type) self.multiline = multiline self.placeholder = placeholder self.default: str @@ -390,8 +385,8 @@ class Combo: default: str=None, control_after_generate: bool=None, image_upload: bool=None, image_folder: FolderType=None, remote: RemoteOptions=None, - socketless: bool=None, types: list[type[ComfyType] | ComfyType]=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type, types) + socketless: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type) self.multiselect = False self.options = options self.control_after_generate = control_after_generate @@ -531,10 +526,19 @@ class MultiType: Type = Any class Input(InputV3): ''' - Input that permits more than one input type. + Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values. ''' - def __init__(self, id: str, types: list[type[ComfyType] | ComfyType], display_name: str=None, optional=False, tooltip: str=None,): - super().__init__(id, display_name, optional, tooltip) + def __init__(self, id: str | ComfyType.Input, types: list[type[ComfyType] | ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + # if id is an Input, then use that Input with overridden values + self.input_override = None + if isinstance(id, InputV3): + self.input_override = id + optional = id.optional if id.optional is True else optional + tooltip = id.tooltip if id.tooltip is not None else tooltip + display_name = id.display_name if id.display_name is not None else display_name + lazy = id.lazy if id.lazy is not None else lazy + id = id.id + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) self._io_types = types @property @@ -552,7 +556,17 @@ class MultiType: def get_io_type_V1(self): # ensure types are unique and order is preserved - return ",".join(list(dict.fromkeys([x.io_type for x in self.io_types]))) + str_types = [x.io_type for x in self.io_types] + if self.input_override is not None: + str_types.insert(0, self.input_override.get_io_type_V1()) + return ",".join(list(dict.fromkeys(str_types))) + + @override + def as_dict_V1(self): + if self.input_override is not None: + return self.input_override.as_dict_V1() | super().as_dict_V1() + else: + return super().as_dict_V1() class DynamicInput(InputV3): ''' diff --git a/comfy_extras/nodes_v3_test.py b/comfy_extras/nodes_v3_test.py index c65926ac2..9f385f952 100644 --- a/comfy_extras/nodes_v3_test.py +++ b/comfy_extras/nodes_v3_test.py @@ -38,9 +38,10 @@ class V3TestNode(io.ComfyNodeV3): io.Custom("JKL").Input("jkl", optional=True), io.Mask.Input("mask", optional=True), io.Int.Input("some_int", display_name="new_name", min=0, max=127, default=42, - tooltip="My tooltip 😎", display_mode=io.NumberDisplay.slider, types=[io.Float]), - io.Combo.Input("combo", options=["a", "b", "c"], tooltip="This is a combo input", types=[io.Mask]), + tooltip="My tooltip 😎", display_mode=io.NumberDisplay.slider), + io.Combo.Input("combo", options=["a", "b", "c"], tooltip="This is a combo input"), io.MultiCombo.Input("combo2", options=["a","b","c"]), + io.MultiType.Input(io.Int.Input("int_multitype", display_name="haha"), types=[io.Float]), io.MultiType.Input("multitype", types=[io.Mask, io.Float, io.Int], optional=True), # ComboInput("combo", image_upload=True, image_folder=FolderType.output, # remote=RemoteOptions(