Support async for v3's execute function, still need to test validate_inputs, fingerprint_inputs, and check_lazy_status, fix Any type for v3 by introducing __ne__ trick from comfy_api's typing.py

This commit is contained in:
Jedrzej Kosinski 2025-07-18 15:50:42 -07:00
parent fd9c34a3eb
commit b6a4a4c664
4 changed files with 120 additions and 8 deletions

View File

@ -1,3 +1,4 @@
import asyncio
from dataclasses import asdict from dataclasses import asdict
from typing import Callable, Optional from typing import Callable, Optional
@ -118,9 +119,18 @@ def make_locked_method_func(type_obj, func, class_clone):
""" """
Returns a function that, when called with **inputs, will execute: Returns a function that, when called with **inputs, will execute:
getattr(type_obj, func).__func__(lock_class(class_clone), **inputs) getattr(type_obj, func).__func__(lock_class(class_clone), **inputs)
Supports both synchronous and asynchronous methods.
""" """
locked_class = lock_class(class_clone) locked_class = lock_class(class_clone)
method = getattr(type_obj, func).__func__ method = getattr(type_obj, func).__func__
def wrapped_func(**inputs):
return method(locked_class, **inputs) # Check if the original method is async
return wrapped_func if asyncio.iscoroutinefunction(method):
async def wrapped_async_func(**inputs):
return await method(locked_class, **inputs)
return wrapped_async_func
else:
def wrapped_func(**inputs):
return method(locked_class, **inputs)
return wrapped_func

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import copy import copy
import inspect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import Counter from collections import Counter
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
@ -75,6 +76,16 @@ class NumberDisplay(str, Enum):
color = "color" color = "color"
class _StringIOType(str):
def __ne__(self, value: object) -> bool:
if self == "*" or value == "*":
return False
if not isinstance(value, str):
return True
a = frozenset(self.split(","))
b = frozenset(value.split(","))
return not (b.issubset(a) or a.issubset(b))
class ComfyType(ABC): class ComfyType(ABC):
Type = Any Type = Any
io_type: str = None io_type: str = None
@ -114,8 +125,8 @@ def comfytype(io_type: str, **kwargs):
new_cls.__module__ = cls.__module__ new_cls.__module__ = cls.__module__
new_cls.__doc__ = cls.__doc__ new_cls.__doc__ = cls.__doc__
# assign ComfyType attributes, if needed # assign ComfyType attributes, if needed
# NOTE: do we need __ne__ trick for io_type? (see node_typing.IO.__ne__ for details) # NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details)
new_cls.io_type = io_type new_cls.io_type = _StringIOType(io_type)
if hasattr(new_cls, "Input") and new_cls.Input is not None: if hasattr(new_cls, "Input") and new_cls.Input is not None:
new_cls.Input.Parent = new_cls new_cls.Input.Parent = new_cls
if hasattr(new_cls, "Output") and new_cls.Output is not None: if hasattr(new_cls, "Output") and new_cls.Output is not None:
@ -169,7 +180,7 @@ class InputV3(IO_V3):
}) | prune_dict(self.extra_dict) }) | prune_dict(self.extra_dict)
def get_io_type(self): def get_io_type(self):
return self.io_type return _StringIOType(self.io_type)
class WidgetInputV3(InputV3): class WidgetInputV3(InputV3):
''' '''
@ -1227,6 +1238,12 @@ class _ComfyNodeBaseInternal(ComfyNodeInternal):
if first_real_override(cls, "execute") is None: if first_real_override(cls, "execute") is None:
raise Exception(f"No execute function was defined for node class {cls.__name__}.") raise Exception(f"No execute function was defined for node class {cls.__name__}.")
@classproperty
def FUNCTION(cls): # noqa
if inspect.iscoroutinefunction(cls.execute):
return "EXECUTE_NORMALIZED_ASYNC"
return "EXECUTE_NORMALIZED"
@final @final
@classmethod @classmethod
def EXECUTE_NORMALIZED(cls, *args, **kwargs) -> NodeOutput: def EXECUTE_NORMALIZED(cls, *args, **kwargs) -> NodeOutput:
@ -1244,6 +1261,23 @@ class _ComfyNodeBaseInternal(ComfyNodeInternal):
else: else:
raise Exception(f"Invalid return type from node: {type(to_return)}") raise Exception(f"Invalid return type from node: {type(to_return)}")
@final
@classmethod
async def EXECUTE_NORMALIZED_ASYNC(cls, *args, **kwargs) -> NodeOutput:
to_return = await cls.execute(*args, **kwargs)
if to_return is None:
return NodeOutput()
elif isinstance(to_return, NodeOutput):
return to_return
elif isinstance(to_return, tuple):
return NodeOutput(*to_return)
elif isinstance(to_return, dict):
return NodeOutput.from_dict(to_return)
elif isinstance(to_return, ExecutionBlocker):
return NodeOutput(block_execution=to_return.message)
else:
raise Exception(f"Invalid return type from node: {type(to_return)}")
@final @final
@classmethod @classmethod
def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNodeV3]: def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNodeV3]:
@ -1366,8 +1400,6 @@ class _ComfyNodeBaseInternal(ComfyNodeInternal):
cls.GET_SCHEMA() cls.GET_SCHEMA()
return cls._NOT_IDEMPOTENT return cls._NOT_IDEMPOTENT
FUNCTION = "EXECUTE_NORMALIZED"
@final @final
@classmethod @classmethod
def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], SchemaV3]: def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], SchemaV3]:

View File

@ -1,5 +1,8 @@
import torch import torch
from comfy.comfy_types.node_typing import ComfyNodeABC, IO from comfy.comfy_types.node_typing import ComfyNodeABC, IO
import asyncio
from comfy.utils import ProgressBar
import time
class TestNode(ComfyNodeABC): class TestNode(ComfyNodeABC):
@ -34,10 +37,41 @@ class TestNode(ComfyNodeABC):
return (some_int, image) return (some_int, image)
class TestSleep(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": (IO.ANY, {}),
"seconds": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 9999.0, "step": 0.01, "tooltip": "The amount of seconds to sleep."}),
},
"hidden": {
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "sleep"
CATEGORY = "_for_testing"
async def sleep(self, value, seconds, unique_id):
pbar = ProgressBar(seconds, node_id=unique_id)
start = time.time()
expiration = start + seconds
now = start
while now < expiration:
now = time.time()
pbar.update_absolute(now - start)
await asyncio.sleep(0.02)
return (value,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"V1TestNode1": TestNode, "V1TestNode1": TestNode,
"V1TestSleep": TestSleep,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
"V1TestNode1": "V1 Test Node", "V1TestNode1": "V1 Test Node",
"V1TestSleep": "V1 Test Sleep",
} }

View File

@ -5,6 +5,7 @@ import logging # noqa
import folder_paths import folder_paths
import comfy.utils import comfy.utils
import comfy.sd import comfy.sd
import asyncio
@io.comfytype(io_type="XYZ") @io.comfytype(io_type="XYZ")
@ -203,8 +204,43 @@ class NInputsTest(io.ComfyNodeV3):
return io.NodeOutput(combined_image) return io.NodeOutput(combined_image)
class V3TestSleep(io.ComfyNodeV3):
@classmethod
def define_schema(cls):
return io.SchemaV3(
node_id="V3_TestSleep",
display_name="V3 Test Sleep",
category="_for_testing",
description="Test async sleep functionality.",
inputs=[
io.AnyType.Input("value", display_name="Value"),
io.Float.Input("seconds", display_name="Seconds", default=1.0, min=0.0, max=9999.0, step=0.01, tooltip="The amount of seconds to sleep."),
],
outputs=[
io.AnyType.Output(),
],
hidden=[
io.Hidden.unique_id,
],
)
@classmethod
async def execute(cls, value: io.AnyType.Type, seconds: io.Float.Type, **kwargs):
logging.info(f"V3TestSleep: {cls.hidden.unique_id}")
pbar = comfy.utils.ProgressBar(seconds, node_id=cls.hidden.unique_id)
start = time.time()
expiration = start + seconds
now = start
while now < expiration:
now = time.time()
pbar.update_absolute(now - start)
await asyncio.sleep(0.02)
return io.NodeOutput(value)
NODES_LIST: list[type[io.ComfyNodeV3]] = [ NODES_LIST: list[type[io.ComfyNodeV3]] = [
V3TestNode, V3TestNode,
V3LoraLoader, V3LoraLoader,
NInputsTest, NInputsTest,
V3TestSleep,
] ]