Mock AutogrowDynamic type

This commit is contained in:
Jedrzej Kosinski 2025-07-04 16:27:03 -05:00
parent 3758c65107
commit 18a7207ca4
4 changed files with 107 additions and 29 deletions

View File

@ -1,10 +1,12 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Literal, TYPE_CHECKING, TypeVar, Callable, Optional, cast, TypedDict, NotRequired from typing import Any, Literal, TYPE_CHECKING, TypeVar, Callable, Optional, cast, TypedDict
from typing_extensions import NotRequired
from enum import Enum from enum import Enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict from dataclasses import dataclass, asdict
from collections import Counter from collections import Counter
from comfy_api.v3.resources import Resources, ResourcesLocal from comfy_api.v3.resources import Resources, ResourcesLocal
import copy
# used for type hinting # used for type hinting
import torch import torch
from spandrel import ImageModelDescriptor from spandrel import ImageModelDescriptor
@ -189,17 +191,19 @@ class WidgetInputV3(InputV3):
''' '''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: Any=None, default: Any=None,
socketless: bool=None, widgetType: str=None, extra_dict=None): socketless: bool=None, widgetType: str=None, extra_dict=None, force_input: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.default = default self.default = default
self.socketless = socketless self.socketless = socketless
self.widgetType = widgetType self.widgetType = widgetType
self.force_input = force_input
def as_dict_V1(self): def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({ return super().as_dict_V1() | prune_dict({
"default": self.default, "default": self.default,
"socketless": self.socketless, "socketless": self.socketless,
"widgetType": self.widgetType, "widgetType": self.widgetType,
"forceInput": self.force_input,
}) })
def get_io_type_V1(self): def get_io_type_V1(self):
@ -754,44 +758,69 @@ class MultiType:
else: else:
return super().as_dict_V1() return super().as_dict_V1()
class DynamicInput(InputV3): class DynamicInput(InputV3, ABC):
''' '''
Abstract class for dynamic input registration. Abstract class for dynamic input registration.
''' '''
def __init__(self, io_type: str, id: str, display_name: str=None): @abstractmethod
super().__init__(io_type, id, display_name) def get_dynamic(self) -> list[InputV3]:
...
class DynamicOutput(OutputV3): class DynamicOutput(OutputV3, ABC):
''' '''
Abstract class for dynamic output registration. Abstract class for dynamic output registration.
''' '''
def __init__(self, io_type: str, id: str, display_name: str=None): @abstractmethod
super().__init__(io_type, id, display_name) def get_dynamic(self) -> list[OutputV3]:
...
# io_type="COMFY_MULTIGROW_V3"
class AutoGrowDynamicInput(DynamicInput):
'''
Dynamic Input that adds another template_input each time one is provided.
Additional inputs are forced to have 'optional=True'. @comfytype(io_type="COMFY_AUTOGROW_V3")
''' class AutogrowDynamic:
def __init__(self, id: str, template_input: InputV3, min: int=1, max: int=None): Type = list[Any]
super().__init__("AutoGrowDynamicInput", id) class Input(DynamicInput):
self.template_input = template_input def __init__(self, id: str, template_input: InputV3, min: int=1, max: int=None,
if min is not None: display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
assert(min >= 1) super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
if max is not None: self.template_input = template_input
assert(max >= 1) if min is not None:
self.min = min assert(min >= 1)
self.max = max if max is not None:
assert(max >= 1)
self.min = min
self.max = max
def get_dynamic(self) -> list[InputV3]:
curr_count = 1
new_inputs = []
for i in range(self.min):
new_input = copy.copy(self.template_input)
new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
if new_input.display_name is not None:
new_input.display_name = f"{new_input.display_name}{curr_count}"
new_input.optional = self.optional or new_input.optional
if isinstance(self.template_input, WidgetInputV3):
new_input.force_input = True
new_inputs.append(new_input)
curr_count += 1
# pretend to expand up to max
for i in range(curr_count-1, self.max):
new_input = copy.copy(self.template_input)
new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
if new_input.display_name is not None:
new_input.display_name = f"{new_input.display_name}{curr_count}"
new_input.optional = True
if isinstance(self.template_input, WidgetInputV3):
new_input.force_input = True
new_inputs.append(new_input)
curr_count += 1
return new_inputs
# io_type="COMFY_COMBODYNAMIC_V3" # io_type="COMFY_COMBODYNAMIC_V3"
class ComboDynamicInput(DynamicInput): class ComboDynamicInput(DynamicInput):
def __init__(self, id: str): def __init__(self, id: str):
pass pass
AutoGrowDynamicInput(id="dynamic", template_input=Image.Input(id="image"))
class HiddenHolder: class HiddenHolder:
def __init__(self, unique_id: str, prompt: Any, def __init__(self, unique_id: str, prompt: Any,
@ -960,6 +989,11 @@ class classproperty(object):
return self.f(owner) return self.f(owner)
def add_to_dict_v1(i: InputV3, input: dict):
key = "optional" if i.optional else "required"
input.setdefault(key, {})[i.id] = (i.get_io_type_V1(), i.as_dict_V1())
class ComfyNodeV3: class ComfyNodeV3:
"""Common base class for all V3 nodes.""" """Common base class for all V3 nodes."""
@ -1117,8 +1151,12 @@ class ComfyNodeV3:
} }
if schema.inputs: if schema.inputs:
for i in schema.inputs: for i in schema.inputs:
key = "optional" if i.optional else "required" if isinstance(i, DynamicInput):
input.setdefault(key, {})[i.id] = (i.get_io_type_V1(), i.as_dict_V1()) dynamic_inputs = i.get_dynamic()
for d in dynamic_inputs:
add_to_dict_v1(d, input)
else:
add_to_dict_v1(i, input)
if schema.hidden and include_hidden: if schema.hidden and include_hidden:
for hidden in schema.hidden: for hidden in schema.hidden:
input.setdefault("hidden", {})[hidden.name] = (hidden.value,) input.setdefault("hidden", {})[hidden.name] = (hidden.value,)

View File

@ -41,7 +41,7 @@ class ResourcesLocal(Resources):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.local_resources: dict[ResourceKey, Any] = {} self.local_resources: dict[ResourceKey, Any] = {}
def get(self, key: ResourceKey, default: Any=...) -> Any: def get(self, key: ResourceKey, default: Any=...) -> Any:
cached = self.local_resources.get(key, None) cached = self.local_resources.get(key, None)
if cached is not None: if cached is not None:

View File

@ -149,7 +149,37 @@ class V3LoraLoader(io.ComfyNodeV3):
return io.NodeOutput(model_lora, clip_lora) return io.NodeOutput(model_lora, clip_lora)
class NInputsTest(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="V3_NInputsTest",
display_name="V3 N Inputs Test",
inputs=[
io.AutogrowDynamic.Input("nmock", template_input=io.Image.Input("image"), min=1, max=3),
io.AutogrowDynamic.Input("nmock2", template_input=io.Int.Input("int"), optional=True, min=1, max=4),
],
outputs=[
io.Image.Output("image_out"),
],
)
@classmethod
def execute(cls, nmock, nmock2):
first_image = nmock[0]
all_images = []
for img in nmock:
if img.shape != first_image.shape:
img = img.movedim(-1,1)
img = comfy.utils.common_upscale(img, first_image.shape[2], first_image.shape[1], "lanczos", "center")
img = img.movedim(1,-1)
all_images.append(img)
combined_image = torch.cat(all_images, dim=0)
return io.NodeOutput(combined_image)
NODES_LIST: list[type[io.ComfyNodeV3]] = [ NODES_LIST: list[type[io.ComfyNodeV3]] = [
V3TestNode, V3TestNode,
V3LoraLoader, V3LoraLoader,
NInputsTest,
] ]

View File

@ -28,7 +28,7 @@ from comfy_execution.graph import (
) )
from comfy_execution.graph_utils import GraphBuilder, is_link from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input from comfy_execution.validation import validate_node_input
from comfy_api.v3.io import NodeOutput, ComfyNodeV3, Hidden, NodeStateLocal, ResourcesLocal from comfy_api.v3.io import NodeOutput, ComfyNodeV3, Hidden, NodeStateLocal, ResourcesLocal, AutogrowDynamic
class ExecutionResult(Enum): class ExecutionResult(Enum):
@ -229,6 +229,16 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
if obj.local_resources is None: if obj.local_resources is None:
obj.local_resources = ResourcesLocal() obj.local_resources = ResourcesLocal()
class_clone.resources = obj.local_resources class_clone.resources = obj.local_resources
# TODO: delete this when done testing mocking dynamic inputs
for si in obj.SCHEMA.inputs:
if isinstance(si, AutogrowDynamic.Input):
add_key = si.id
dynamic_list = []
real_inputs = {k: v for k, v in inputs.items()}
for d in si.get_dynamic():
dynamic_list.append(real_inputs.pop(d.id, None))
dynamic_list = [x for x in dynamic_list if x is not None]
inputs = {**real_inputs, add_key: dynamic_list}
results.append(getattr(type(obj), func).__func__(class_clone, **inputs)) results.append(getattr(type(obj), func).__func__(class_clone, **inputs))
# V1 # V1
else: else: