* Support for async execution functions

  This commit adds support for node execution functions defined as async. When a node's execution function is defined as async, we can continue executing other nodes while it is processing.

  Standard uses of `await` should "just work", but people will still have to be careful if they spawn actual threads. Because torch doesn't really have async/await versions of functions, this won't particularly help with most locally-executing nodes, but it does work for e.g. web requests to other machines.

  In addition to the execute function, the `VALIDATE_INPUTS` and `check_lazy_status` functions can also be defined as async, though we'll only resolve one node at a time right now for those.

* Add the execution model tests to CI

* Add a missing file

  It looks like this got caught by .gitignore? There's probably a better place to put it, but I'm not sure what that is.

* Add the websocket library for automated tests

* Add additional tests for async error cases

  Also fixes one bug that was found when an async function throws an error after being scheduled on a task.

* Add a feature flags message to reduce bandwidth

  We now only send 1 preview message of the latest type the client can support. We'll add a console warning when the client fails to send a feature flags message at some point in the future.

* Add async tests to CI

* Don't actually add new tests in this PR

  Will do it in a separate PR

* Resolve unit test in GPU-less runner

* Just remove the tests that GHA can't handle

* Change line endings to UNIX-style

* Avoid loading model_management.py so early

  Because model_management.py has a top-level `logging.info`, we have to be careful not to import that file before we call `setup_logging`. If we do, we end up having the default logging handler registered in addition to our custom one.
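For illustration, here is a minimal sketch of the pattern the commit message describes: an execution function declared `async`, so the executor can keep running other nodes while it awaits. The `ExampleAsyncFetch` node, its `url` input, and the `asyncio.sleep` stand-in for real network I/O are hypothetical; the actual test nodes are in the file below.

```python
import asyncio

from comfy.comfy_types.node_typing import ComfyNodeABC


class ExampleAsyncFetch(ComfyNodeABC):
    """Hypothetical node whose execution function is async."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "url": ("STRING", {"default": "http://localhost"}),
            },
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "fetch"
    CATEGORY = "_for_testing/async"

    async def fetch(self, url):
        # Stand-in for a slow remote call (e.g. a web request to another
        # machine); ordinary `await` is all that is needed, and other nodes
        # can execute while this one is suspended here.
        await asyncio.sleep(0.1)
        return (f"fetched {url}",)
```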
import torch
import asyncio
from typing import Dict
from comfy.utils import ProgressBar
from comfy_execution.graph_utils import GraphBuilder
from comfy.comfy_types.node_typing import ComfyNodeABC
from comfy.comfy_types import IO


class TestAsyncValidation(ComfyNodeABC):
    """Test node with async VALIDATE_INPUTS."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("FLOAT", {"default": 5.0}),
                "threshold": ("FLOAT", {"default": 10.0}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "process"
    CATEGORY = "_for_testing/async"

    @classmethod
    async def VALIDATE_INPUTS(cls, value, threshold):
        # Simulate async validation (e.g., checking remote service)
        await asyncio.sleep(0.05)

        if value > threshold:
            return f"Value {value} exceeds threshold {threshold}"
        return True

    def process(self, value, threshold):
        # Create image based on value
        intensity = value / 10.0
        image = torch.ones([1, 512, 512, 3]) * intensity
        return (image,)


class TestAsyncError(ComfyNodeABC):
    """Test node that errors during async execution."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": (IO.ANY, {}),
                "error_after": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 10.0}),
            },
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "error_execution"
    CATEGORY = "_for_testing/async"

    async def error_execution(self, value, error_after):
        await asyncio.sleep(error_after)
        raise RuntimeError("Intentional async execution error for testing")


class TestAsyncValidationError(ComfyNodeABC):
    """Test node with async validation that always fails."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("FLOAT", {"default": 5.0}),
                "max_value": ("FLOAT", {"default": 10.0}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "process"
    CATEGORY = "_for_testing/async"

    @classmethod
    async def VALIDATE_INPUTS(cls, value, max_value):
        await asyncio.sleep(0.05)
        # Always fail validation for values > max_value
        if value > max_value:
            return f"Async validation failed: {value} > {max_value}"
        return True

    def process(self, value, max_value):
        # This won't be reached if validation fails
        image = torch.ones([1, 512, 512, 3]) * (value / max_value)
        return (image,)


class TestAsyncTimeout(ComfyNodeABC):
    """Test node that simulates timeout scenarios."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": (IO.ANY, {}),
                "timeout": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0}),
                "operation_time": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 10.0}),
            },
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "timeout_execution"
    CATEGORY = "_for_testing/async"

    async def timeout_execution(self, value, timeout, operation_time):
        try:
            # This will time out if operation_time > timeout
            await asyncio.wait_for(asyncio.sleep(operation_time), timeout=timeout)
            return (value,)
        except asyncio.TimeoutError:
            raise RuntimeError(f"Operation timed out after {timeout} seconds")


class TestSyncError(ComfyNodeABC):
    """Test node that errors synchronously (for mixed sync/async testing)."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": (IO.ANY, {}),
            },
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "sync_error"
    CATEGORY = "_for_testing/async"

    def sync_error(self, value):
        raise RuntimeError("Intentional sync execution error for testing")


class TestAsyncLazyCheck(ComfyNodeABC):
    """Test node with async check_lazy_status."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input1": (IO.ANY, {"lazy": True}),
                "input2": (IO.ANY, {"lazy": True}),
                "condition": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "process"
    CATEGORY = "_for_testing/async"

    async def check_lazy_status(self, condition, input1, input2):
        # Simulate async checking (e.g., querying remote service)
        await asyncio.sleep(0.05)

        needed = []
        if condition and input1 is None:
            needed.append("input1")
        if not condition and input2 is None:
            needed.append("input2")
        return needed

    def process(self, input1, input2, condition):
        # Return a simple image
        return (torch.ones([1, 512, 512, 3]),)


class TestDynamicAsyncGeneration(ComfyNodeABC):
    """Test node that dynamically generates async nodes."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image1": ("IMAGE",),
                "image2": ("IMAGE",),
                "num_async_nodes": ("INT", {"default": 3, "min": 1, "max": 10}),
                "sleep_duration": ("FLOAT", {"default": 0.2, "min": 0.1, "max": 1.0}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "generate_async_workflow"
    CATEGORY = "_for_testing/async"

    def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration):
        g = GraphBuilder()

        # Create multiple async sleep nodes
        sleep_nodes = []
        for i in range(num_async_nodes):
            image = image1 if i % 2 == 0 else image2
            sleep_node = g.node("TestSleep", value=image, seconds=sleep_duration)
            sleep_nodes.append(sleep_node)

        # Average all results
        if len(sleep_nodes) == 1:
            final_node = sleep_nodes[0]
        else:
            avg_inputs = {"input1": sleep_nodes[0].out(0)}
            for i, node in enumerate(sleep_nodes[1:], 2):
                avg_inputs[f"input{i}"] = node.out(0)
            final_node = g.node("TestVariadicAverage", **avg_inputs)

        return {
            "result": (final_node.out(0),),
            "expand": g.finalize(),
        }


class TestAsyncResourceUser(ComfyNodeABC):
    """Test node that uses resources during async execution."""

    # Class-level resource tracking for testing
    _active_resources: Dict[str, bool] = {}

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": (IO.ANY, {}),
                "resource_id": ("STRING", {"default": "resource_0"}),
                "duration": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0}),
            },
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "use_resource"
    CATEGORY = "_for_testing/async"

    async def use_resource(self, value, resource_id, duration):
        # Check if resource is already in use
        if self._active_resources.get(resource_id, False):
            raise RuntimeError(f"Resource {resource_id} is already in use!")

        # Mark resource as in use
        self._active_resources[resource_id] = True

        try:
            # Simulate resource usage
            await asyncio.sleep(duration)
            return (value,)
        finally:
            # Always clean up resource
            self._active_resources[resource_id] = False


class TestAsyncBatchProcessing(ComfyNodeABC):
    """Test async processing of batched inputs."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "process_time_per_item": ("FLOAT", {"default": 0.1, "min": 0.01, "max": 1.0}),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "process_batch"
    CATEGORY = "_for_testing/async"

    async def process_batch(self, images, process_time_per_item, unique_id):
        batch_size = images.shape[0]
        pbar = ProgressBar(batch_size, node_id=unique_id)

        # Process each image in the batch
        processed = []
        for i in range(batch_size):
            # Simulate async processing
            await asyncio.sleep(process_time_per_item)

            # Simple processing: invert the image
            processed_image = 1.0 - images[i:i+1]
            processed.append(processed_image)

            pbar.update(1)

        # Stack processed images
        result = torch.cat(processed, dim=0)
        return (result,)


class TestAsyncConcurrentLimit(ComfyNodeABC):
    """Test concurrent execution limits for async nodes."""

    _semaphore = asyncio.Semaphore(2)  # Only allow 2 concurrent executions

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": (IO.ANY, {}),
                "duration": ("FLOAT", {"default": 0.5, "min": 0.1, "max": 2.0}),
                "node_id": ("INT", {"default": 0}),
            },
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "limited_execution"
    CATEGORY = "_for_testing/async"

    async def limited_execution(self, value, duration, node_id):
        async with self._semaphore:
            # Node {node_id} acquired semaphore
            await asyncio.sleep(duration)
            # Node {node_id} releasing semaphore
            return (value,)


# Add node mappings
ASYNC_TEST_NODE_CLASS_MAPPINGS = {
    "TestAsyncValidation": TestAsyncValidation,
    "TestAsyncError": TestAsyncError,
    "TestAsyncValidationError": TestAsyncValidationError,
    "TestAsyncTimeout": TestAsyncTimeout,
    "TestSyncError": TestSyncError,
    "TestAsyncLazyCheck": TestAsyncLazyCheck,
    "TestDynamicAsyncGeneration": TestDynamicAsyncGeneration,
    "TestAsyncResourceUser": TestAsyncResourceUser,
    "TestAsyncBatchProcessing": TestAsyncBatchProcessing,
    "TestAsyncConcurrentLimit": TestAsyncConcurrentLimit,
}

ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS = {
    "TestAsyncValidation": "Test Async Validation",
    "TestAsyncError": "Test Async Error",
    "TestAsyncValidationError": "Test Async Validation Error",
    "TestAsyncTimeout": "Test Async Timeout",
    "TestSyncError": "Test Sync Error",
    "TestAsyncLazyCheck": "Test Async Lazy Check",
    "TestDynamicAsyncGeneration": "Test Dynamic Async Generation",
    "TestAsyncResourceUser": "Test Async Resource User",
    "TestAsyncBatchProcessing": "Test Async Batch Processing",
    "TestAsyncConcurrentLimit": "Test Async Concurrent Limit",
}