diff --git a/comfy_api/v3/io.py b/comfy_api/v3/io.py index 957bae802..b0a4931c7 100644 --- a/comfy_api/v3/io.py +++ b/comfy_api/v3/io.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Any, Literal, TYPE_CHECKING, TypeVar, Callable, Optional, cast +from typing import Any, Literal, TYPE_CHECKING, TypeVar, Callable, Optional, cast, override from enum import Enum from abc import ABC, abstractmethod from dataclasses import dataclass, asdict @@ -178,12 +178,11 @@ class WidgetInputV3(InputV3): ''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: Any=None, - socketless: bool=None, widgetType: str=None, types: list[type[ComfyType] | ComfyType]=None, extra_dict=None): + socketless: bool=None, widgetType: str=None, extra_dict=None): super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) self.default = default self.socketless = socketless self.widgetType = widgetType - self.types = types if types is not None else [] def as_dict_V1(self): return super().as_dict_V1() | prune_dict({ @@ -193,11 +192,7 @@ class WidgetInputV3(InputV3): }) def get_io_type_V1(self): - # combine passed-in types and expected widgetType - str_types = [x.io_type for x in self.types] - str_types.insert(0, self.widgetType) - # ensure types are unique and order is preserved - return ','.join(list(dict.fromkeys(str_types))) + return self.widgetType if self.widgetType is not None else super().get_io_type_V1() class OutputV3(IO_V3): def __init__(self, id: str, display_name: str=None, tooltip: str=None, @@ -220,15 +215,40 @@ class NodeState(ABC): def __init__(self, node_id: str): self.node_id = node_id + @abstractmethod + def get_value(self, key: str): + pass + + @abstractmethod + def set_value(self, key: str, value: Any): + pass + @abstractmethod def pop(self, key: str): pass + @abstractmethod + def __contains__(self, key: str): + pass + + class NodeStateLocal(NodeState): def __init__(self, node_id: str): super().__init__(node_id) self.local_state = {} + def get_value(self, key: str): + return self.local_state.get(key) + + def set_value(self, key: str, value: Any): + self.local_state[key] = value + + def pop(self, key: str): + return self.local_state.pop(key, None) + + def __contains__(self, key: str): + return key in self.local_state + def __getattr__(self, key: str): local_state = type(self).__getattribute__(self, "local_state") if key in local_state: @@ -248,15 +268,9 @@ class NodeStateLocal(NodeState): def __getitem__(self, key: str): return self.local_state[key] - def __contains__(self, key: str): - return key in self.local_state - def __delitem__(self, key: str): del self.local_state[key] - def pop(self, key: str): - return self.local_state.pop(key) - @comfytype(io_type=IO.BOOLEAN) class Boolean: Type = bool @@ -265,8 +279,8 @@ class Boolean: '''Boolean input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: bool=None, label_on: str=None, label_off: str=None, - socketless: bool=None, types: list[type[ComfyType] | ComfyType]=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type, types) + socketless: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type) self.label_on = label_on self.label_off = label_off self.default: bool @@ -288,8 +302,8 @@ class Int: '''Integer input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, - display_mode: NumberDisplay=None, socketless: bool=None, types: list[type[ComfyType] | ComfyType]=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type, types) + display_mode: NumberDisplay=None, socketless: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type) self.min = min self.max = max self.step = step @@ -317,9 +331,8 @@ class Float: '''Float input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, - display_mode: NumberDisplay=None, socketless: bool=None, types: list[type[ComfyType] | ComfyType]=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type, types) - self.default = default + display_mode: NumberDisplay=None, socketless: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type) self.min = min self.max = max self.step = step @@ -347,8 +360,8 @@ class String: '''String input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, multiline=False, placeholder: str=None, default: int=None, - socketless: bool=None, types: list[type[ComfyType] | ComfyType]=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type, types) + socketless: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type) self.multiline = multiline self.placeholder = placeholder self.default: str @@ -372,8 +385,8 @@ class Combo: default: str=None, control_after_generate: bool=None, image_upload: bool=None, image_folder: FolderType=None, remote: RemoteOptions=None, - socketless: bool=None, types: list[type[ComfyType] | ComfyType]=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type, types) + socketless: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type) self.multiselect = False self.options = options self.control_after_generate = control_after_generate @@ -513,10 +526,19 @@ class MultiType: Type = Any class Input(InputV3): ''' - Input that permits more than one input type. + Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values. ''' - def __init__(self, id: str, types: list[type[ComfyType] | ComfyType], display_name: str=None, optional=False, tooltip: str=None,): - super().__init__(id, display_name, optional, tooltip) + def __init__(self, id: str | ComfyType.Input, types: list[type[ComfyType] | ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + # if id is an Input, then use that Input with overridden values + self.input_override = None + if isinstance(id, InputV3): + self.input_override = id + optional = id.optional if id.optional is True else optional + tooltip = id.tooltip if id.tooltip is not None else tooltip + display_name = id.display_name if id.display_name is not None else display_name + lazy = id.lazy if id.lazy is not None else lazy + id = id.id + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) self._io_types = types @property @@ -534,7 +556,17 @@ class MultiType: def get_io_type_V1(self): # ensure types are unique and order is preserved - return ",".join(list(dict.fromkeys([x.io_type for x in self.io_types]))) + str_types = [x.io_type for x in self.io_types] + if self.input_override is not None: + str_types.insert(0, self.input_override.get_io_type_V1()) + return ",".join(list(dict.fromkeys(str_types))) + + @override + def as_dict_V1(self): + if self.input_override is not None: + return self.input_override.as_dict_V1() | super().as_dict_V1() + else: + return super().as_dict_V1() class DynamicInput(InputV3): ''' diff --git a/comfy_extras/nodes_v3_test.py b/comfy_extras/nodes_v3_test.py index d6f9fcc76..9f385f952 100644 --- a/comfy_extras/nodes_v3_test.py +++ b/comfy_extras/nodes_v3_test.py @@ -38,9 +38,10 @@ class V3TestNode(io.ComfyNodeV3): io.Custom("JKL").Input("jkl", optional=True), io.Mask.Input("mask", optional=True), io.Int.Input("some_int", display_name="new_name", min=0, max=127, default=42, - tooltip="My tooltip 😎", display_mode=io.NumberDisplay.slider, types=[io.Float]), - io.Combo.Input("combo", options=["a", "b", "c"], tooltip="This is a combo input", types=[io.Mask]), + tooltip="My tooltip 😎", display_mode=io.NumberDisplay.slider), + io.Combo.Input("combo", options=["a", "b", "c"], tooltip="This is a combo input"), io.MultiCombo.Input("combo2", options=["a","b","c"]), + io.MultiType.Input(io.Int.Input("int_multitype", display_name="haha"), types=[io.Float]), io.MultiType.Input("multitype", types=[io.Mask, io.Float, io.Int], optional=True), # ComboInput("combo", image_upload=True, image_folder=FolderType.output, # remote=RemoteOptions( @@ -80,11 +81,15 @@ class V3TestNode(io.ComfyNodeV3): cls.state["thing"] = "hahaha" yyy = cls.state["thing"] del cls.state["thing"] + if cls.state.get_value("int2") is None: + cls.state.set_value("int2", 123) + zzz = cls.state.get_value("int2") + cls.state.pop("int2") if cls.state.my_int is None: cls.state.my_int = expected_int else: if cls.state.my_int != expected_int: - raise Exception(f"Explicit state object did not maintain expected value: {cls.state.my_int} != {expected_int}") + raise Exception(f"Explicit state object did not maintain expected value (__getattr__/__setattr__): {cls.state.my_int} != {expected_int}") #some_int if hasattr(cls, "hahajkunless"): raise Exception("The 'cls' variable leaked instance state between runs!") @@ -158,7 +163,7 @@ class V3LoraLoader(io.ComfyNodeV3): return io.NodeOutput(model_lora, clip_lora) -NODES_LIST: list[io.ComfyNodeV3] = [ +NODES_LIST: list[type[io.ComfyNodeV3]] = [ V3TestNode, V3LoraLoader, ]