Multitype refactor progress

This commit is contained in:
Jedrzej Kosinski 2025-06-26 15:41:49 -07:00
parent 6ef4ad2a4c
commit aefd845a21
2 changed files with 39 additions and 24 deletions

View File

@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Any, Literal, TYPE_CHECKING, TypeVar, Callable, Optional, cast
from typing import Any, Literal, TYPE_CHECKING, TypeVar, Callable, Optional, cast, override
from enum import Enum
from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict
@ -178,12 +178,11 @@ class WidgetInputV3(InputV3):
'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: Any=None,
socketless: bool=None, widgetType: str=None, types: list[type[ComfyType] | ComfyType]=None, extra_dict=None):
socketless: bool=None, widgetType: str=None, extra_dict=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.default = default
self.socketless = socketless
self.widgetType = widgetType
self.types = types if types is not None else []
def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({
@ -193,11 +192,7 @@ class WidgetInputV3(InputV3):
})
def get_io_type_V1(self):
# combine passed-in types and expected widgetType
str_types = [x.io_type for x in self.types]
str_types.insert(0, self.widgetType)
# ensure types are unique and order is preserved
return ','.join(list(dict.fromkeys(str_types)))
return self.widgetType if self.widgetType is not None else super().get_io_type_V1()
class OutputV3(IO_V3):
def __init__(self, id: str, display_name: str=None, tooltip: str=None,
@ -284,8 +279,8 @@ class Boolean:
'''Boolean input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: bool=None, label_on: str=None, label_off: str=None,
socketless: bool=None, types: list[type[ComfyType] | ComfyType]=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type, types)
socketless: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type)
self.label_on = label_on
self.label_off = label_off
self.default: bool
@ -307,8 +302,8 @@ class Int:
'''Integer input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
display_mode: NumberDisplay=None, socketless: bool=None, types: list[type[ComfyType] | ComfyType]=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type, types)
display_mode: NumberDisplay=None, socketless: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type)
self.min = min
self.max = max
self.step = step
@ -336,8 +331,8 @@ class Float:
'''Float input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
display_mode: NumberDisplay=None, socketless: bool=None, types: list[type[ComfyType] | ComfyType]=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type, types)
display_mode: NumberDisplay=None, socketless: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type)
self.min = min
self.max = max
self.step = step
@ -365,8 +360,8 @@ class String:
'''String input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
multiline=False, placeholder: str=None, default: int=None,
socketless: bool=None, types: list[type[ComfyType] | ComfyType]=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type, types)
socketless: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type)
self.multiline = multiline
self.placeholder = placeholder
self.default: str
@ -390,8 +385,8 @@ class Combo:
default: str=None, control_after_generate: bool=None,
image_upload: bool=None, image_folder: FolderType=None,
remote: RemoteOptions=None,
socketless: bool=None, types: list[type[ComfyType] | ComfyType]=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type, types)
socketless: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type)
self.multiselect = False
self.options = options
self.control_after_generate = control_after_generate
@ -531,10 +526,19 @@ class MultiType:
Type = Any
class Input(InputV3):
'''
Input that permits more than one input type.
Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values.
'''
def __init__(self, id: str, types: list[type[ComfyType] | ComfyType], display_name: str=None, optional=False, tooltip: str=None,):
super().__init__(id, display_name, optional, tooltip)
def __init__(self, id: str | ComfyType.Input, types: list[type[ComfyType] | ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
# if id is an Input, then use that Input with overridden values
self.input_override = None
if isinstance(id, InputV3):
self.input_override = id
optional = id.optional if id.optional is True else optional
tooltip = id.tooltip if id.tooltip is not None else tooltip
display_name = id.display_name if id.display_name is not None else display_name
lazy = id.lazy if id.lazy is not None else lazy
id = id.id
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self._io_types = types
@property
@ -552,7 +556,17 @@ class MultiType:
def get_io_type_V1(self):
# ensure types are unique and order is preserved
return ",".join(list(dict.fromkeys([x.io_type for x in self.io_types])))
str_types = [x.io_type for x in self.io_types]
if self.input_override is not None:
str_types.insert(0, self.input_override.get_io_type_V1())
return ",".join(list(dict.fromkeys(str_types)))
@override
def as_dict_V1(self):
if self.input_override is not None:
return self.input_override.as_dict_V1() | super().as_dict_V1()
else:
return super().as_dict_V1()
class DynamicInput(InputV3):
'''

View File

@ -38,9 +38,10 @@ class V3TestNode(io.ComfyNodeV3):
io.Custom("JKL").Input("jkl", optional=True),
io.Mask.Input("mask", optional=True),
io.Int.Input("some_int", display_name="new_name", min=0, max=127, default=42,
tooltip="My tooltip 😎", display_mode=io.NumberDisplay.slider, types=[io.Float]),
io.Combo.Input("combo", options=["a", "b", "c"], tooltip="This is a combo input", types=[io.Mask]),
tooltip="My tooltip 😎", display_mode=io.NumberDisplay.slider),
io.Combo.Input("combo", options=["a", "b", "c"], tooltip="This is a combo input"),
io.MultiCombo.Input("combo2", options=["a","b","c"]),
io.MultiType.Input(io.Int.Input("int_multitype", display_name="haha"), types=[io.Float]),
io.MultiType.Input("multitype", types=[io.Mask, io.Float, io.Int], optional=True),
# ComboInput("combo", image_upload=True, image_folder=FolderType.output,
# remote=RemoteOptions(