Merge branch 'v3-definition' of https://github.com/comfyanonymous/ComfyUI into v3-definition

This commit is contained in:
kosinkadink1@gmail.com 2025-07-09 03:58:16 -05:00
commit 5f91e2905a
6 changed files with 268 additions and 81 deletions

View File

@ -1,10 +1,12 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Literal, TYPE_CHECKING, TypeVar, Callable, Optional, cast, TypedDict, NotRequired from typing import Any, Literal, TYPE_CHECKING, TypeVar, Callable, Optional, cast, TypedDict
from typing_extensions import NotRequired
from enum import Enum from enum import Enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict from dataclasses import dataclass, asdict
from collections import Counter from collections import Counter
from comfy_api.v3.resources import Resources, ResourcesLocal from comfy_api.v3.resources import Resources, ResourcesLocal
import copy
# used for type hinting # used for type hinting
import torch import torch
from spandrel import ImageModelDescriptor from spandrel import ImageModelDescriptor
@ -189,17 +191,19 @@ class WidgetInputV3(InputV3):
''' '''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: Any=None, default: Any=None,
socketless: bool=None, widgetType: str=None, extra_dict=None): socketless: bool=None, widgetType: str=None, force_input: bool=None, extra_dict=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.default = default self.default = default
self.socketless = socketless self.socketless = socketless
self.widgetType = widgetType self.widgetType = widgetType
self.force_input = force_input
def as_dict_V1(self): def as_dict_V1(self):
return super().as_dict_V1() | prune_dict({ return super().as_dict_V1() | prune_dict({
"default": self.default, "default": self.default,
"socketless": self.socketless, "socketless": self.socketless,
"widgetType": self.widgetType, "widgetType": self.widgetType,
"forceInput": self.force_input,
}) })
def get_io_type_V1(self): def get_io_type_V1(self):
@ -291,8 +295,8 @@ class Boolean:
'''Boolean input.''' '''Boolean input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: bool=None, label_on: str=None, label_off: str=None, default: bool=None, label_on: str=None, label_off: str=None,
socketless: bool=None): socketless: bool=None, force_input: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type, force_input)
self.label_on = label_on self.label_on = label_on
self.label_off = label_off self.label_off = label_off
self.default: bool self.default: bool
@ -314,8 +318,8 @@ class Int:
'''Integer input.''' '''Integer input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
display_mode: NumberDisplay=None, socketless: bool=None): display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type, force_input)
self.min = min self.min = min
self.max = max self.max = max
self.step = step self.step = step
@ -343,8 +347,8 @@ class Float(ComfyTypeIO):
'''Float input.''' '''Float input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
display_mode: NumberDisplay=None, socketless: bool=None): display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type, force_input)
self.min = min self.min = min
self.max = max self.max = max
self.step = step self.step = step
@ -369,8 +373,8 @@ class String(ComfyTypeIO):
'''String input.''' '''String input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
multiline=False, placeholder: str=None, default: int=None, multiline=False, placeholder: str=None, default: int=None,
socketless: bool=None): socketless: bool=None, force_input: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, self.io_type, force_input)
self.multiline = multiline self.multiline = multiline
self.placeholder = placeholder self.placeholder = placeholder
self.default: str self.default: str
@ -429,7 +433,7 @@ class MultiCombo(ComfyType):
def as_dict_V1(self): def as_dict_V1(self):
to_return = super().as_dict_V1() | prune_dict({ to_return = super().as_dict_V1() | prune_dict({
"multiselect": self.multiselect, "multi_select": self.multiselect,
"placeholder": self.placeholder, "placeholder": self.placeholder,
"chip": self.chip, "chip": self.chip,
}) })
@ -754,29 +758,30 @@ class MultiType:
else: else:
return super().as_dict_V1() return super().as_dict_V1()
class DynamicInput(InputV3): class DynamicInput(InputV3, ABC):
''' '''
Abstract class for dynamic input registration. Abstract class for dynamic input registration.
''' '''
def __init__(self, io_type: str, id: str, display_name: str=None): @abstractmethod
super().__init__(io_type, id, display_name) def get_dynamic(self) -> list[InputV3]:
...
class DynamicOutput(OutputV3): class DynamicOutput(OutputV3, ABC):
''' '''
Abstract class for dynamic output registration. Abstract class for dynamic output registration.
''' '''
def __init__(self, io_type: str, id: str, display_name: str=None): @abstractmethod
super().__init__(io_type, id, display_name) def get_dynamic(self) -> list[OutputV3]:
...
# io_type="COMFY_MULTIGROW_V3"
class AutoGrowDynamicInput(DynamicInput):
'''
Dynamic Input that adds another template_input each time one is provided.
Additional inputs are forced to have 'optional=True'. @comfytype(io_type="COMFY_AUTOGROW_V3")
''' class AutogrowDynamic:
def __init__(self, id: str, template_input: InputV3, min: int=1, max: int=None): Type = list[Any]
super().__init__("AutoGrowDynamicInput", id) class Input(DynamicInput):
def __init__(self, id: str, template_input: InputV3, min: int=1, max: int=None,
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.template_input = template_input self.template_input = template_input
if min is not None: if min is not None:
assert(min >= 1) assert(min >= 1)
@ -785,13 +790,37 @@ class AutoGrowDynamicInput(DynamicInput):
self.min = min self.min = min
self.max = max self.max = max
def get_dynamic(self) -> list[InputV3]:
curr_count = 1
new_inputs = []
for i in range(self.min):
new_input = copy.copy(self.template_input)
new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
if new_input.display_name is not None:
new_input.display_name = f"{new_input.display_name}{curr_count}"
new_input.optional = self.optional or new_input.optional
if isinstance(self.template_input, WidgetInputV3):
new_input.force_input = True
new_inputs.append(new_input)
curr_count += 1
# pretend to expand up to max
for i in range(curr_count-1, self.max):
new_input = copy.copy(self.template_input)
new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
if new_input.display_name is not None:
new_input.display_name = f"{new_input.display_name}{curr_count}"
new_input.optional = True
if isinstance(self.template_input, WidgetInputV3):
new_input.force_input = True
new_inputs.append(new_input)
curr_count += 1
return new_inputs
# io_type="COMFY_COMBODYNAMIC_V3" # io_type="COMFY_COMBODYNAMIC_V3"
class ComboDynamicInput(DynamicInput): class ComboDynamicInput(DynamicInput):
def __init__(self, id: str): def __init__(self, id: str):
pass pass
AutoGrowDynamicInput(id="dynamic", template_input=Image.Input(id="image"))
class HiddenHolder: class HiddenHolder:
def __init__(self, unique_id: str, prompt: Any, def __init__(self, unique_id: str, prompt: Any,
@ -815,7 +844,9 @@ class HiddenHolder:
return None return None
@classmethod @classmethod
def from_dict(cls, d: dict): def from_dict(cls, d: dict | None):
if d is None:
d = {}
return cls( return cls(
unique_id=d.get(Hidden.unique_id, None), unique_id=d.get(Hidden.unique_id, None),
prompt=d.get(Hidden.prompt, None), prompt=d.get(Hidden.prompt, None),
@ -939,6 +970,26 @@ class SchemaV3:
if len(issues) > 0: if len(issues) > 0:
raise ValueError("\n".join(issues)) raise ValueError("\n".join(issues))
def finalize(self):
"""Add hidden based on selected schema options."""
# if is an api_node, will need key-related hidden
if self.is_api_node:
if self.hidden is None:
self.hidden = []
if Hidden.auth_token_comfy_org not in self.hidden:
self.hidden.append(Hidden.auth_token_comfy_org)
if Hidden.api_key_comfy_org not in self.hidden:
self.hidden.append(Hidden.api_key_comfy_org)
# if is an output_node, will need prompt and extra_pnginfo
if self.is_output_node:
if self.hidden is None:
self.hidden = []
if Hidden.prompt not in self.hidden:
self.hidden.append(Hidden.prompt)
if Hidden.extra_pnginfo not in self.hidden:
self.hidden.append(Hidden.extra_pnginfo)
class Serializer: class Serializer:
def __init_subclass__(cls, io_type: str, **kwargs): def __init_subclass__(cls, io_type: str, **kwargs):
cls.io_type = io_type cls.io_type = io_type
@ -960,6 +1011,11 @@ class classproperty(object):
return self.f(owner) return self.f(owner)
def add_to_dict_v1(i: InputV3, input: dict):
key = "optional" if i.optional else "required"
input.setdefault(key, {})[i.id] = (i.get_io_type_V1(), i.as_dict_V1())
class ComfyNodeV3: class ComfyNodeV3:
"""Common base class for all V3 nodes.""" """Common base class for all V3 nodes."""
@ -971,12 +1027,6 @@ class ComfyNodeV3:
resources: Resources = None resources: Resources = None
hidden: HiddenHolder = None hidden: HiddenHolder = None
@classmethod
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
schema = cls.GET_SCHEMA()
# TODO: finish
return None
@classmethod @classmethod
@abstractmethod @abstractmethod
def DEFINE_SCHEMA(cls) -> SchemaV3: def DEFINE_SCHEMA(cls) -> SchemaV3:
@ -992,10 +1042,46 @@ class ComfyNodeV3:
pass pass
execute = None execute = None
@classmethod
def validate_inputs(cls, **kwargs) -> bool:
"""Optionally, define this function to validate inputs; equivalnet to V1's VALIDATE_INPUTS."""
pass
validate_inputs = None
@classmethod
def fingerprint_inputs(cls, **kwargs) -> Any:
"""Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED."""
pass
fingerprint_inputs = None
@classmethod
def check_lazy_status(cls, **kwargs) -> list[str]:
"""Optionally, define this function to return a list of input names that should be evaluated.
This basic mixin impl. requires all inputs.
:kwargs: All node inputs will be included here. If the input is ``None``, it should be assumed that it has not yet been evaluated. \
When using ``INPUT_IS_LIST = True``, unevaluated will instead be ``(None,)``.
Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name).
Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params).
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lazy_evaluation#defining-check-lazy-status
"""
need = [name for name in kwargs if kwargs[name] is None]
return need
check_lazy_status = None
@classmethod @classmethod
def GET_SERIALIZERS(cls) -> list[Serializer]: def GET_SERIALIZERS(cls) -> list[Serializer]:
return [] return []
@classmethod
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
schema = cls.GET_SCHEMA()
# TODO: finish
return None
def __init__(self): def __init__(self):
self.local_state: NodeStateLocal = None self.local_state: NodeStateLocal = None
self.local_resources: ResourcesLocal = None self.local_resources: ResourcesLocal = None
@ -1110,15 +1196,19 @@ class ComfyNodeV3:
@classmethod @classmethod
def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], SchemaV3]: def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], SchemaV3]:
schema = cls.DEFINE_SCHEMA() schema = cls.FINALIZE_SCHEMA()
# for V1, make inputs be a dict with potential keys {required, optional, hidden} # for V1, make inputs be a dict with potential keys {required, optional, hidden}
input = { input = {
"required": {} "required": {}
} }
if schema.inputs: if schema.inputs:
for i in schema.inputs: for i in schema.inputs:
key = "optional" if i.optional else "required" if isinstance(i, DynamicInput):
input.setdefault(key, {})[i.id] = (i.get_io_type_V1(), i.as_dict_V1()) dynamic_inputs = i.get_dynamic()
for d in dynamic_inputs:
add_to_dict_v1(d, input)
else:
add_to_dict_v1(i, input)
if schema.hidden and include_hidden: if schema.hidden and include_hidden:
for hidden in schema.hidden: for hidden in schema.hidden:
input.setdefault("hidden", {})[hidden.name] = (hidden.value,) input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
@ -1127,9 +1217,17 @@ class ComfyNodeV3:
return input return input
@classmethod @classmethod
def GET_SCHEMA(cls) -> SchemaV3: def FINALIZE_SCHEMA(cls):
cls.VALIDATE_CLASS() """Call DEFINE_SCHEMA and finalize it."""
schema = cls.DEFINE_SCHEMA() schema = cls.DEFINE_SCHEMA()
schema.finalize()
return schema
@classmethod
def GET_SCHEMA(cls) -> SchemaV3:
"""Validate node class, finalize schema, validate schema, and set expected class properties."""
cls.VALIDATE_CLASS()
schema = cls.FINALIZE_SCHEMA()
schema.validate() schema.validate()
if cls._DESCRIPTION is None: if cls._DESCRIPTION is None:
cls._DESCRIPTION = schema.description cls._DESCRIPTION = schema.description

View File

@ -1,12 +1,15 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from comfy_api.v3.io import Image, Mask, FolderType, _UIOutput from comfy_api.v3.io import Image, Mask, FolderType, _UIOutput, ComfyNodeV3
# used for image preview # used for image preview
from comfy.cli_args import args
import folder_paths import folder_paths
import random import random
from PIL import Image as PILImage from PIL import Image as PILImage
from PIL.PngImagePlugin import PngInfo
import os import os
import json
import numpy as np import numpy as np
@ -24,7 +27,7 @@ class SavedResult:
} }
class PreviewImage(_UIOutput): class PreviewImage(_UIOutput):
def __init__(self, image: Image.Type, animated: bool=False, **kwargs): def __init__(self, image: Image.Type, animated: bool=False, node: ComfyNodeV3=None, **kwargs):
output_dir = folder_paths.get_temp_directory() output_dir = folder_paths.get_temp_directory()
type = "temp" type = "temp"
prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
@ -38,13 +41,13 @@ class PreviewImage(_UIOutput):
i = 255. * image.cpu().numpy() i = 255. * image.cpu().numpy()
img = PILImage.fromarray(np.clip(i, 0, 255).astype(np.uint8)) img = PILImage.fromarray(np.clip(i, 0, 255).astype(np.uint8))
metadata = None metadata = None
# if not args.disable_metadata: if not args.disable_metadata and node is not None:
# metadata = PngInfo() metadata = PngInfo()
# if prompt is not None: if node.hidden.prompt is not None:
# metadata.add_text("prompt", json.dumps(prompt)) metadata.add_text("prompt", json.dumps(node.hidden.prompt))
# if extra_pnginfo is not None: if node.hidden.extra_pnginfo is not None:
# for x in extra_pnginfo: for x in node.hidden.extra_pnginfo:
# metadata.add_text(x, json.dumps(extra_pnginfo[x])) metadata.add_text(x, json.dumps(node.hidden.extra_pnginfo[x]))
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.png" file = f"{filename_with_batch_num}_{counter:05}_.png"
@ -63,9 +66,9 @@ class PreviewImage(_UIOutput):
} }
class PreviewMask(PreviewImage): class PreviewMask(PreviewImage):
def __init__(self, mask: PreviewMask.Type, animated: bool=False, **kwargs): def __init__(self, mask: PreviewMask.Type, animated: bool=False, node: ComfyNodeV3=None, **kwargs):
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
super().__init__(preview, animated, **kwargs) super().__init__(preview, animated, node, **kwargs)
# class UILatent(_UIOutput): # class UILatent(_UIOutput):
# def __init__(self, values: list[SavedResult | dict], **kwargs): # def __init__(self, values: list[SavedResult | dict], **kwargs):

View File

@ -13,6 +13,7 @@ class TestNode(ComfyNodeABC):
"min": 0, "max": 127, "default": 42, "min": 0, "max": 127, "default": 42,
"tooltip": "My tooltip 😎", "display": "slider"}), "tooltip": "My tooltip 😎", "display": "slider"}),
"combo": (IO.COMBO, {"options": ["a", "b", "c"], "tooltip": "This is a combo input"}), "combo": (IO.COMBO, {"options": ["a", "b", "c"], "tooltip": "This is a combo input"}),
"combo2": (IO.COMBO, {"options": ["a", "b", "c"], "multi_select": True, "tooltip": "This is a combo input"}),
}, },
"optional": { "optional": {
"xyz": ("XYZ",), "xyz": ("XYZ",),
@ -29,7 +30,7 @@ class TestNode(ComfyNodeABC):
CATEGORY = "v3 nodes" CATEGORY = "v3 nodes"
def do_thing(self, image: torch.Tensor, some_int: int, combo: str, xyz=None, mask: torch.Tensor=None): def do_thing(self, image: torch.Tensor, some_int: int, combo: str, combo2: list[str], xyz=None, mask: torch.Tensor=None):
return (some_int, image) return (some_int, image)

View File

@ -1,4 +1,5 @@
import torch import torch
import time
from comfy_api.v3 import io, ui, resources from comfy_api.v3 import io, ui, resources
import logging import logging
import folder_paths import folder_paths
@ -72,6 +73,14 @@ class V3TestNode(io.ComfyNodeV3):
is_output_node=True, is_output_node=True,
) )
@classmethod
def validate_inputs(cls, image: io.Image.Type, some_int: int, combo: io.Combo.Type, combo2: io.MultiCombo.Type, xyz: XYZ.Type=None, mask: io.Mask.Type=None, **kwargs):
if some_int < 0:
raise Exception("some_int must be greater than 0")
if combo == "c":
raise Exception("combo must be a or b")
return True
@classmethod @classmethod
def execute(cls, image: io.Image.Type, some_int: int, combo: io.Combo.Type, combo2: io.MultiCombo.Type, xyz: XYZ.Type=None, mask: io.Mask.Type=None, **kwargs): def execute(cls, image: io.Image.Type, some_int: int, combo: io.Combo.Type, combo2: io.MultiCombo.Type, xyz: XYZ.Type=None, mask: io.Mask.Type=None, **kwargs):
zzz = cls.hidden.prompt zzz = cls.hidden.prompt
@ -149,7 +158,50 @@ class V3LoraLoader(io.ComfyNodeV3):
return io.NodeOutput(model_lora, clip_lora) return io.NodeOutput(model_lora, clip_lora)
class NInputsTest(io.ComfyNodeV3):
@classmethod
def DEFINE_SCHEMA(cls):
return io.SchemaV3(
node_id="V3_NInputsTest",
display_name="V3 N Inputs Test",
inputs=[
io.AutogrowDynamic.Input("nmock", template_input=io.Image.Input("image"), min=1, max=3),
io.AutogrowDynamic.Input("nmock2", template_input=io.Int.Input("int"), optional=True, min=1, max=4),
],
outputs=[
io.Image.Output("image_out"),
],
)
@classmethod
def validate_inputs(cls, nmock, nmock2):
return True
@classmethod
def fingerprint_inputs(cls, nmock, nmock2):
return time.time()
@classmethod
def check_lazy_status(cls, **kwargs) -> list[str]:
need = [name for name in kwargs if kwargs[name] is None]
return need
@classmethod
def execute(cls, nmock, nmock2):
first_image = nmock[0]
all_images = []
for img in nmock:
if img.shape != first_image.shape:
img = img.movedim(-1,1)
img = comfy.utils.common_upscale(img, first_image.shape[2], first_image.shape[1], "lanczos", "center")
img = img.movedim(1,-1)
all_images.append(img)
combined_image = torch.cat(all_images, dim=0)
return io.NodeOutput(combined_image)
NODES_LIST: list[type[io.ComfyNodeV3]] = [ NODES_LIST: list[type[io.ComfyNodeV3]] = [
V3TestNode, V3TestNode,
V3LoraLoader, V3LoraLoader,
NInputsTest,
] ]

View File

@ -28,7 +28,7 @@ from comfy_execution.graph import (
) )
from comfy_execution.graph_utils import GraphBuilder, is_link from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input from comfy_execution.validation import validate_node_input
from comfy_api.v3.io import NodeOutput, ComfyNodeV3, Hidden, NodeStateLocal, ResourcesLocal from comfy_api.v3.io import NodeOutput, ComfyNodeV3, Hidden, NodeStateLocal, ResourcesLocal, AutogrowDynamic, is_class
class ExecutionResult(Enum): class ExecutionResult(Enum):
@ -52,7 +52,15 @@ class IsChangedCache:
node = self.dynprompt.get_node(node_id) node = self.dynprompt.get_node(node_id)
class_type = node["class_type"] class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if not hasattr(class_def, "IS_CHANGED"): has_is_changed = False
is_changed_name = None
if issubclass(class_def, ComfyNodeV3) and getattr(class_def, "fingerprint_inputs", None) is not None:
has_is_changed = True
is_changed_name = "fingerprint_inputs"
elif hasattr(class_def, "IS_CHANGED"):
has_is_changed = True
is_changed_name = "IS_CHANGED"
if not has_is_changed:
self.is_changed[node_id] = False self.is_changed[node_id] = False
return self.is_changed[node_id] return self.is_changed[node_id]
@ -63,7 +71,7 @@ class IsChangedCache:
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None) input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None)
try: try:
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED") is_changed = _map_node_over_list(class_def, input_data_all, is_changed_name)
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
except Exception as e: except Exception as e:
logging.warning("WARNING: {}".format(e)) logging.warning("WARNING: {}".format(e))
@ -216,7 +224,15 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
if pre_execute_cb is not None and index is not None: if pre_execute_cb is not None and index is not None:
pre_execute_cb(index) pre_execute_cb(index)
# V3 # V3
if isinstance(obj, ComfyNodeV3): if isinstance(obj, ComfyNodeV3) or (is_class(obj) and issubclass(obj, ComfyNodeV3)):
# if is just a class, then assign no resources or state, just create clone
if is_class(obj):
type_obj = obj
obj.VALIDATE_CLASS()
class_clone = obj.prepare_class_clone(hidden_inputs)
# otherwise, use class instance to populate/reuse some fields
else:
type_obj = type(obj)
type(obj).VALIDATE_CLASS() type(obj).VALIDATE_CLASS()
class_clone = type(obj).prepare_class_clone(hidden_inputs) class_clone = type(obj).prepare_class_clone(hidden_inputs)
# NOTE: this is a mock of state management; for local, just stores NodeStateLocal on node instance # NOTE: this is a mock of state management; for local, just stores NodeStateLocal on node instance
@ -229,7 +245,17 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
if obj.local_resources is None: if obj.local_resources is None:
obj.local_resources = ResourcesLocal() obj.local_resources = ResourcesLocal()
class_clone.resources = obj.local_resources class_clone.resources = obj.local_resources
results.append(getattr(type(obj), func).__func__(class_clone, **inputs)) # TODO: delete this when done testing mocking dynamic inputs
for si in obj.SCHEMA.inputs:
if isinstance(si, AutogrowDynamic.Input):
add_key = si.id
dynamic_list = []
real_inputs = {k: v for k, v in inputs.items()}
for d in si.get_dynamic():
dynamic_list.append(real_inputs.pop(d.id, None))
dynamic_list = [x for x in dynamic_list if x is not None]
inputs = {**real_inputs, add_key: dynamic_list}
results.append(getattr(type_obj, func).__func__(class_clone, **inputs))
# V1 # V1
else: else:
results.append(getattr(obj, func)(**inputs)) results.append(getattr(obj, func)(**inputs))
@ -382,7 +408,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
obj = class_def() obj = class_def()
caches.objects.set(unique_id, obj) caches.objects.set(unique_id, obj)
if hasattr(obj, "check_lazy_status"): if getattr(obj, "check_lazy_status", None) is not None:
required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs) required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
required_inputs = [x for x in required_inputs if isinstance(x,str) and ( required_inputs = [x for x in required_inputs if isinstance(x,str) and (
@ -641,8 +667,16 @@ def validate_inputs(prompt, item, validated):
validate_function_inputs = [] validate_function_inputs = []
validate_has_kwargs = False validate_has_kwargs = False
if hasattr(obj_class, "VALIDATE_INPUTS"): validate_function_name = None
argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS) validate_function = None
if issubclass(obj_class, ComfyNodeV3):
validate_function_name = "validate_inputs"
validate_function = getattr(obj_class, validate_function_name, None)
else:
validate_function_name = "VALIDATE_INPUTS"
validate_function = getattr(obj_class, validate_function_name, None)
if validate_function is not None:
argspec = inspect.getfullargspec(validate_function)
validate_function_inputs = argspec.args validate_function_inputs = argspec.args
validate_has_kwargs = argspec.varkw is not None validate_has_kwargs = argspec.varkw is not None
received_types = {} received_types = {}
@ -825,8 +859,7 @@ def validate_inputs(prompt, item, validated):
if 'input_types' in validate_function_inputs: if 'input_types' in validate_function_inputs:
input_filtered['input_types'] = [received_types] input_filtered['input_types'] = [received_types]
#ret = obj_class.VALIDATE_INPUTS(**input_filtered) ret = _map_node_over_list(obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs)
ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS", hidden_inputs=hidden_inputs)
for x in input_filtered: for x in input_filtered:
for i, r in enumerate(ret): for i, r in enumerate(ret):
if r is not True and not isinstance(r, ExecutionBlocker): if r is not True and not isinstance(r, ExecutionBlocker):