mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-29 17:26:34 +00:00
Support async for v3's execute function, still need to test validate_inputs, fingerprint_inputs, and check_lazy_status, fix Any type for v3 by introducing __ne__ trick from comfy_api's typing.py
This commit is contained in:
parent
fd9c34a3eb
commit
b6a4a4c664
@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
@ -118,9 +119,18 @@ def make_locked_method_func(type_obj, func, class_clone):
|
|||||||
"""
|
"""
|
||||||
Returns a function that, when called with **inputs, will execute:
|
Returns a function that, when called with **inputs, will execute:
|
||||||
getattr(type_obj, func).__func__(lock_class(class_clone), **inputs)
|
getattr(type_obj, func).__func__(lock_class(class_clone), **inputs)
|
||||||
|
|
||||||
|
Supports both synchronous and asynchronous methods.
|
||||||
"""
|
"""
|
||||||
locked_class = lock_class(class_clone)
|
locked_class = lock_class(class_clone)
|
||||||
method = getattr(type_obj, func).__func__
|
method = getattr(type_obj, func).__func__
|
||||||
def wrapped_func(**inputs):
|
|
||||||
return method(locked_class, **inputs)
|
# Check if the original method is async
|
||||||
return wrapped_func
|
if asyncio.iscoroutinefunction(method):
|
||||||
|
async def wrapped_async_func(**inputs):
|
||||||
|
return await method(locked_class, **inputs)
|
||||||
|
return wrapped_async_func
|
||||||
|
else:
|
||||||
|
def wrapped_func(**inputs):
|
||||||
|
return method(locked_class, **inputs)
|
||||||
|
return wrapped_func
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import inspect
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
@ -75,6 +76,16 @@ class NumberDisplay(str, Enum):
|
|||||||
color = "color"
|
color = "color"
|
||||||
|
|
||||||
|
|
||||||
|
class _StringIOType(str):
|
||||||
|
def __ne__(self, value: object) -> bool:
|
||||||
|
if self == "*" or value == "*":
|
||||||
|
return False
|
||||||
|
if not isinstance(value, str):
|
||||||
|
return True
|
||||||
|
a = frozenset(self.split(","))
|
||||||
|
b = frozenset(value.split(","))
|
||||||
|
return not (b.issubset(a) or a.issubset(b))
|
||||||
|
|
||||||
class ComfyType(ABC):
|
class ComfyType(ABC):
|
||||||
Type = Any
|
Type = Any
|
||||||
io_type: str = None
|
io_type: str = None
|
||||||
@ -114,8 +125,8 @@ def comfytype(io_type: str, **kwargs):
|
|||||||
new_cls.__module__ = cls.__module__
|
new_cls.__module__ = cls.__module__
|
||||||
new_cls.__doc__ = cls.__doc__
|
new_cls.__doc__ = cls.__doc__
|
||||||
# assign ComfyType attributes, if needed
|
# assign ComfyType attributes, if needed
|
||||||
# NOTE: do we need __ne__ trick for io_type? (see node_typing.IO.__ne__ for details)
|
# NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details)
|
||||||
new_cls.io_type = io_type
|
new_cls.io_type = _StringIOType(io_type)
|
||||||
if hasattr(new_cls, "Input") and new_cls.Input is not None:
|
if hasattr(new_cls, "Input") and new_cls.Input is not None:
|
||||||
new_cls.Input.Parent = new_cls
|
new_cls.Input.Parent = new_cls
|
||||||
if hasattr(new_cls, "Output") and new_cls.Output is not None:
|
if hasattr(new_cls, "Output") and new_cls.Output is not None:
|
||||||
@ -169,7 +180,7 @@ class InputV3(IO_V3):
|
|||||||
}) | prune_dict(self.extra_dict)
|
}) | prune_dict(self.extra_dict)
|
||||||
|
|
||||||
def get_io_type(self):
|
def get_io_type(self):
|
||||||
return self.io_type
|
return _StringIOType(self.io_type)
|
||||||
|
|
||||||
class WidgetInputV3(InputV3):
|
class WidgetInputV3(InputV3):
|
||||||
'''
|
'''
|
||||||
@ -1227,6 +1238,12 @@ class _ComfyNodeBaseInternal(ComfyNodeInternal):
|
|||||||
if first_real_override(cls, "execute") is None:
|
if first_real_override(cls, "execute") is None:
|
||||||
raise Exception(f"No execute function was defined for node class {cls.__name__}.")
|
raise Exception(f"No execute function was defined for node class {cls.__name__}.")
|
||||||
|
|
||||||
|
@classproperty
|
||||||
|
def FUNCTION(cls): # noqa
|
||||||
|
if inspect.iscoroutinefunction(cls.execute):
|
||||||
|
return "EXECUTE_NORMALIZED_ASYNC"
|
||||||
|
return "EXECUTE_NORMALIZED"
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@classmethod
|
@classmethod
|
||||||
def EXECUTE_NORMALIZED(cls, *args, **kwargs) -> NodeOutput:
|
def EXECUTE_NORMALIZED(cls, *args, **kwargs) -> NodeOutput:
|
||||||
@ -1244,6 +1261,23 @@ class _ComfyNodeBaseInternal(ComfyNodeInternal):
|
|||||||
else:
|
else:
|
||||||
raise Exception(f"Invalid return type from node: {type(to_return)}")
|
raise Exception(f"Invalid return type from node: {type(to_return)}")
|
||||||
|
|
||||||
|
@final
|
||||||
|
@classmethod
|
||||||
|
async def EXECUTE_NORMALIZED_ASYNC(cls, *args, **kwargs) -> NodeOutput:
|
||||||
|
to_return = await cls.execute(*args, **kwargs)
|
||||||
|
if to_return is None:
|
||||||
|
return NodeOutput()
|
||||||
|
elif isinstance(to_return, NodeOutput):
|
||||||
|
return to_return
|
||||||
|
elif isinstance(to_return, tuple):
|
||||||
|
return NodeOutput(*to_return)
|
||||||
|
elif isinstance(to_return, dict):
|
||||||
|
return NodeOutput.from_dict(to_return)
|
||||||
|
elif isinstance(to_return, ExecutionBlocker):
|
||||||
|
return NodeOutput(block_execution=to_return.message)
|
||||||
|
else:
|
||||||
|
raise Exception(f"Invalid return type from node: {type(to_return)}")
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@classmethod
|
@classmethod
|
||||||
def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNodeV3]:
|
def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNodeV3]:
|
||||||
@ -1366,8 +1400,6 @@ class _ComfyNodeBaseInternal(ComfyNodeInternal):
|
|||||||
cls.GET_SCHEMA()
|
cls.GET_SCHEMA()
|
||||||
return cls._NOT_IDEMPOTENT
|
return cls._NOT_IDEMPOTENT
|
||||||
|
|
||||||
FUNCTION = "EXECUTE_NORMALIZED"
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], SchemaV3]:
|
def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], SchemaV3]:
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
from comfy.comfy_types.node_typing import ComfyNodeABC, IO
|
from comfy.comfy_types.node_typing import ComfyNodeABC, IO
|
||||||
|
import asyncio
|
||||||
|
from comfy.utils import ProgressBar
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
class TestNode(ComfyNodeABC):
|
class TestNode(ComfyNodeABC):
|
||||||
@ -34,10 +37,41 @@ class TestNode(ComfyNodeABC):
|
|||||||
return (some_int, image)
|
return (some_int, image)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSleep(ComfyNodeABC):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"value": (IO.ANY, {}),
|
||||||
|
"seconds": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 9999.0, "step": 0.01, "tooltip": "The amount of seconds to sleep."}),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
RETURN_TYPES = (IO.ANY,)
|
||||||
|
FUNCTION = "sleep"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
async def sleep(self, value, seconds, unique_id):
|
||||||
|
pbar = ProgressBar(seconds, node_id=unique_id)
|
||||||
|
start = time.time()
|
||||||
|
expiration = start + seconds
|
||||||
|
now = start
|
||||||
|
while now < expiration:
|
||||||
|
now = time.time()
|
||||||
|
pbar.update_absolute(now - start)
|
||||||
|
await asyncio.sleep(0.02)
|
||||||
|
return (value,)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"V1TestNode1": TestNode,
|
"V1TestNode1": TestNode,
|
||||||
|
"V1TestSleep": TestSleep,
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"V1TestNode1": "V1 Test Node",
|
"V1TestNode1": "V1 Test Node",
|
||||||
|
"V1TestSleep": "V1 Test Sleep",
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import logging # noqa
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.sd
|
import comfy.sd
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
@io.comfytype(io_type="XYZ")
|
@io.comfytype(io_type="XYZ")
|
||||||
@ -203,8 +204,43 @@ class NInputsTest(io.ComfyNodeV3):
|
|||||||
return io.NodeOutput(combined_image)
|
return io.NodeOutput(combined_image)
|
||||||
|
|
||||||
|
|
||||||
|
class V3TestSleep(io.ComfyNodeV3):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.SchemaV3(
|
||||||
|
node_id="V3_TestSleep",
|
||||||
|
display_name="V3 Test Sleep",
|
||||||
|
category="_for_testing",
|
||||||
|
description="Test async sleep functionality.",
|
||||||
|
inputs=[
|
||||||
|
io.AnyType.Input("value", display_name="Value"),
|
||||||
|
io.Float.Input("seconds", display_name="Seconds", default=1.0, min=0.0, max=9999.0, step=0.01, tooltip="The amount of seconds to sleep."),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.AnyType.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
io.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(cls, value: io.AnyType.Type, seconds: io.Float.Type, **kwargs):
|
||||||
|
logging.info(f"V3TestSleep: {cls.hidden.unique_id}")
|
||||||
|
pbar = comfy.utils.ProgressBar(seconds, node_id=cls.hidden.unique_id)
|
||||||
|
start = time.time()
|
||||||
|
expiration = start + seconds
|
||||||
|
now = start
|
||||||
|
while now < expiration:
|
||||||
|
now = time.time()
|
||||||
|
pbar.update_absolute(now - start)
|
||||||
|
await asyncio.sleep(0.02)
|
||||||
|
return io.NodeOutput(value)
|
||||||
|
|
||||||
|
|
||||||
NODES_LIST: list[type[io.ComfyNodeV3]] = [
|
NODES_LIST: list[type[io.ComfyNodeV3]] = [
|
||||||
V3TestNode,
|
V3TestNode,
|
||||||
V3LoraLoader,
|
V3LoraLoader,
|
||||||
NInputsTest,
|
NInputsTest,
|
||||||
|
V3TestSleep,
|
||||||
]
|
]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user