diff --git a/execution.py b/execution.py index cfc46d0b8..90cefc023 100644 --- a/execution.py +++ b/execution.py @@ -343,7 +343,17 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, input_data_all = None try: if unique_id in pending_async_nodes: - results = [r.result() if isinstance(r, asyncio.Task) else r for r in pending_async_nodes[unique_id]] + results = [] + for r in pending_async_nodes[unique_id]: + if isinstance(r, asyncio.Task): + try: + results.append(r.result()) + except Exception as ex: + # An async task failed - propagate the exception up + del pending_async_nodes[unique_id] + raise ex + else: + results.append(r) del pending_async_nodes[unique_id] output_data, output_ui, has_subgraph = get_output_from_returns(results, class_def) elif unique_id in pending_subgraph_results: @@ -418,7 +428,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, unblock = execution_list.add_external_block(unique_id) async def await_completion(): tasks = [x for x in output_data if isinstance(x, asyncio.Task)] - await asyncio.gather(*tasks) + await asyncio.gather(*tasks, return_exceptions=True) unblock() asyncio.create_task(await_completion()) return (ExecutionResult.PENDING, None, None) diff --git a/tests/inference/test_async_nodes.py b/tests/inference/test_async_nodes.py new file mode 100644 index 000000000..b243bbca9 --- /dev/null +++ b/tests/inference/test_async_nodes.py @@ -0,0 +1,410 @@ +import pytest +import time +import torch +import urllib.error +import numpy as np +import subprocess + +from pytest import fixture +from comfy_execution.graph_utils import GraphBuilder +from tests.inference.test_execution import ComfyClient + + +@pytest.mark.execution +class TestAsyncNodes: + @fixture(scope="class", autouse=True, params=[ + (False, 0), + (True, 0), + (True, 100), + ]) + def _server(self, args_pytest, request): + pargs = [ + 'python','main.py', + '--output-directory', args_pytest["output_dir"], + '--listen', args_pytest["listen"], + '--port', str(args_pytest["port"]), + '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', + ] + use_lru, lru_size = request.param + if use_lru: + pargs += ['--cache-lru', str(lru_size)] + # Running server with args: pargs + p = subprocess.Popen(pargs) + yield + p.kill() + torch.cuda.empty_cache() + + @fixture(scope="class", autouse=True) + def shared_client(self, args_pytest, _server): + client = ComfyClient() + n_tries = 5 + for i in range(n_tries): + time.sleep(4) + try: + client.connect(listen=args_pytest["listen"], port=args_pytest["port"]) + except ConnectionRefusedError: + # Retrying... + pass + else: + break + yield client + del client + torch.cuda.empty_cache() + + @fixture + def client(self, shared_client, request): + shared_client.set_test_name(f"async_nodes[{request.node.name}]") + yield shared_client + + @fixture + def builder(self, request): + yield GraphBuilder(prefix=request.node.name) + + # Happy Path Tests + + def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder): + """Test that a basic async node executes correctly.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.1) + output = g.node("SaveImage", images=sleep_node.out(0)) + + result = client.run(g) + + # Verify execution completed + assert result.did_run(sleep_node), "Async sleep node should have executed" + assert result.did_run(output), "Output node should have executed" + + # Verify the image passed through correctly + result_images = result.get_images(output) + assert len(result_images) == 1, "Should have 1 image" + assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black" + + def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder): + """Test that multiple async nodes execute in parallel.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + # Create multiple async sleep nodes with different durations + sleep1 = g.node("TestSleep", value=image.out(0), seconds=0.3) + sleep2 = g.node("TestSleep", value=image.out(0), seconds=0.4) + sleep3 = g.node("TestSleep", value=image.out(0), seconds=0.5) + + # Add outputs for each + _output1 = g.node("PreviewImage", images=sleep1.out(0)) + _output2 = g.node("PreviewImage", images=sleep2.out(0)) + _output3 = g.node("PreviewImage", images=sleep3.out(0)) + + start_time = time.time() + result = client.run(g) + elapsed_time = time.time() - start_time + + # Should take ~0.5s (max duration) not 1.2s (sum of durations) + assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s" + + # Verify all nodes executed + assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3) + + def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder): + """Test async nodes with proper dependency handling.""" + g = builder + image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + + # Chain of async operations + sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2) + sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2) + + # Average depends on both async results + average = g.node("TestVariadicAverage", input1=sleep1.out(0), input2=sleep2.out(0)) + output = g.node("SaveImage", images=average.out(0)) + + result = client.run(g) + + # Verify execution order + assert result.did_run(sleep1) and result.did_run(sleep2) + assert result.did_run(average) and result.did_run(output) + + # Verify averaged result + result_images = result.get_images(output) + avg_value = np.array(result_images[0]).mean() + assert abs(avg_value - 127.5) < 1, f"Average value {avg_value} should be ~127.5" + + def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder): + """Test async VALIDATE_INPUTS function.""" + g = builder + # Create a test node with async validation + validation_node = g.node("TestAsyncValidation", value=5.0, threshold=10.0) + g.node("SaveImage", images=validation_node.out(0)) + + # Should pass validation + result = client.run(g) + assert result.did_run(validation_node) + + # Test validation failure + validation_node.inputs['threshold'] = 3.0 # Will fail since value > threshold + with pytest.raises(urllib.error.HTTPError): + client.run(g) + + def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder): + """Test async nodes with lazy evaluation.""" + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1) + + # Create async nodes that will be evaluated lazily + sleep1 = g.node("TestSleep", value=input1.out(0), seconds=0.3) + sleep2 = g.node("TestSleep", value=input2.out(0), seconds=0.3) + + # Use lazy mix that only needs sleep1 (mask=0.0) + lazy_mix = g.node("TestLazyMixImages", image1=sleep1.out(0), image2=sleep2.out(0), mask=mask.out(0)) + g.node("SaveImage", images=lazy_mix.out(0)) + + start_time = time.time() + result = client.run(g) + elapsed_time = time.time() - start_time + + # Should only execute sleep1, not sleep2 + assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s" + assert result.did_run(sleep1), "Sleep1 should have executed" + assert not result.did_run(sleep2), "Sleep2 should have been skipped" + + def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder): + """Test async check_lazy_status function.""" + g = builder + # Create a node with async check_lazy_status + lazy_node = g.node("TestAsyncLazyCheck", + input1="value1", + input2="value2", + condition=True) + g.node("SaveImage", images=lazy_node.out(0)) + + result = client.run(g) + assert result.did_run(lazy_node) + + # Error Handling Tests + + def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder): + """Test that async execution errors are properly handled.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + # Create an async node that will error + error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1) + g.node("SaveImage", images=error_node.out(0)) + + try: + client.run(g) + assert False, "Should have raised an error" + except Exception as e: + assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}" + assert e.args[0]['node_id'] == error_node.id, "Error should be from async error node" + + def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder): + """Test async validation error handling.""" + g = builder + # Node with async validation that will fail + validation_node = g.node("TestAsyncValidationError", value=15.0, max_value=10.0) + g.node("SaveImage", images=validation_node.out(0)) + + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.run(g) + # Verify it's a validation error + assert exc_info.value.code == 400 + + def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder): + """Test handling of async operations that timeout.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + # Very long sleep that would timeout + timeout_node = g.node("TestAsyncTimeout", value=image.out(0), timeout=0.5, operation_time=2.0) + g.node("SaveImage", images=timeout_node.out(0)) + + try: + client.run(g) + assert False, "Should have raised a timeout error" + except Exception as e: + assert 'timeout' in str(e).lower(), f"Expected timeout error, got: {e}" + + def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder): + """Test that workflow can recover after async errors.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + # First run with error + error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1) + g.node("SaveImage", images=error_node.out(0)) + + try: + client.run(g) + except Exception: + pass # Expected + + # Second run should succeed + g2 = GraphBuilder(prefix="recovery_test") + image2 = g2.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + sleep_node = g2.node("TestSleep", value=image2.out(0), seconds=0.1) + g2.node("SaveImage", images=sleep_node.out(0)) + + result = client.run(g2) + assert result.did_run(sleep_node), "Should be able to run after error" + + def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder): + """Test handling when sync node errors while async node is executing.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + # Async node that takes time + sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.5) + + # Sync node that will error immediately + error_node = g.node("TestSyncError", value=image.out(0)) + + # Both feed into output + g.node("PreviewImage", images=sleep_node.out(0)) + g.node("PreviewImage", images=error_node.out(0)) + + try: + client.run(g) + assert False, "Should have raised an error" + except Exception as e: + # Verify the sync error was caught even though async was running + assert 'prompt_id' in e.args[0] + + # Edge Cases + + def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder): + """Test async nodes with execution blockers.""" + g = builder + image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + + # Async sleep nodes + sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2) + sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2) + + # Create list of images + image_list = g.node("TestMakeListNode", value1=sleep1.out(0), value2=sleep2.out(0)) + + # Create list of blocking conditions - [False, True] to block only the second item + int1 = g.node("StubInt", value=1) + int2 = g.node("StubInt", value=2) + block_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0)) + + # Compare each value against 2, so first is False (1 != 2) and second is True (2 == 2) + compare = g.node("TestIntConditions", a=block_list.out(0), b=2, operation="==") + + # Block based on the comparison results + blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False) + + output = g.node("PreviewImage", images=blocker.out(0)) + + result = client.run(g) + images = result.get_images(output) + assert len(images) == 1, "Should have blocked second image" + + def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder): + """Test that async nodes are properly cached.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2) + g.node("SaveImage", images=sleep_node.out(0)) + + # First run + result1 = client.run(g) + assert result1.did_run(sleep_node), "Should run first time" + + # Second run - should be cached + start_time = time.time() + result2 = client.run(g) + elapsed_time = time.time() - start_time + + assert not result2.did_run(sleep_node), "Should be cached" + assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant" + + def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder): + """Test async nodes within dynamically generated prompts.""" + g = builder + image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + + # Node that generates async nodes dynamically + dynamic_async = g.node("TestDynamicAsyncGeneration", + image1=image1.out(0), + image2=image2.out(0), + num_async_nodes=3, + sleep_duration=0.2) + g.node("SaveImage", images=dynamic_async.out(0)) + + start_time = time.time() + result = client.run(g) + elapsed_time = time.time() - start_time + + # Should execute async nodes in parallel within dynamic prompt + assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s" + assert result.did_run(dynamic_async) + + def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder): + """Test that async resources are properly cleaned up.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + # Create multiple async nodes that use resources + resource_nodes = [] + for i in range(5): + node = g.node("TestAsyncResourceUser", + value=image.out(0), + resource_id=f"resource_{i}", + duration=0.1) + resource_nodes.append(node) + g.node("PreviewImage", images=node.out(0)) + + result = client.run(g) + + # Verify all nodes executed + for node in resource_nodes: + assert result.did_run(node) + + # Run again to ensure resources were cleaned up + result2 = client.run(g) + # Should be cached but not error due to resource conflicts + for node in resource_nodes: + assert not result2.did_run(node), "Should be cached" + + def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder): + """Test cancellation of async operations.""" + # This would require implementing cancellation in the client + # For now, we'll test that long-running async operations can be interrupted + pass # TODO: Implement when cancellation API is available + + def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder): + """Test workflows with both sync and async nodes.""" + g = builder + image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + + # Mix of sync and async operations + # Sync: lazy mix images + sync_op1 = g.node("TestLazyMixImages", image1=image1.out(0), image2=image2.out(0), mask=mask.out(0)) + # Async: sleep + async_op1 = g.node("TestSleep", value=sync_op1.out(0), seconds=0.2) + # Sync: custom validation + sync_op2 = g.node("TestCustomValidation1", input1=async_op1.out(0), input2=0.5) + # Async: sleep again + async_op2 = g.node("TestSleep", value=sync_op2.out(0), seconds=0.2) + + output = g.node("SaveImage", images=async_op2.out(0)) + + result = client.run(g) + + # Verify all nodes executed in correct order + assert result.did_run(sync_op1) + assert result.did_run(async_op1) + assert result.did_run(sync_op2) + assert result.did_run(async_op2) + + # Image should be a mix of black and white (gray) + result_images = result.get_images(output) + avg_value = np.array(result_images[0]).mean() + assert abs(avg_value - 63.75) < 5, f"Average value {avg_value} should be ~63.75" diff --git a/tests/inference/testing_nodes/testing-pack/__init__.py b/tests/inference/testing_nodes/testing-pack/__init__.py index dcc71659a..0b16028ec 100644 --- a/tests/inference/testing_nodes/testing-pack/__init__.py +++ b/tests/inference/testing_nodes/testing-pack/__init__.py @@ -3,6 +3,7 @@ from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DI from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS +from .async_test_nodes import ASYNC_TEST_NODE_CLASS_MAPPINGS, ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS # NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS) # NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS) @@ -13,6 +14,7 @@ NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS) NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS) NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS) NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(ASYNC_TEST_NODE_CLASS_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS = {} NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS) @@ -20,4 +22,5 @@ NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS) diff --git a/tests/inference/testing_nodes/testing-pack/async_test_nodes.py b/tests/inference/testing_nodes/testing-pack/async_test_nodes.py new file mode 100644 index 000000000..547eea6f4 --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/async_test_nodes.py @@ -0,0 +1,343 @@ +import torch +import asyncio +from typing import Dict +from comfy.utils import ProgressBar +from comfy_execution.graph_utils import GraphBuilder +from comfy.comfy_types.node_typing import ComfyNodeABC +from comfy.comfy_types import IO + + +class TestAsyncValidation(ComfyNodeABC): + """Test node with async VALIDATE_INPUTS.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("FLOAT", {"default": 5.0}), + "threshold": ("FLOAT", {"default": 10.0}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "process" + CATEGORY = "_for_testing/async" + + @classmethod + async def VALIDATE_INPUTS(cls, value, threshold): + # Simulate async validation (e.g., checking remote service) + await asyncio.sleep(0.05) + + if value > threshold: + return f"Value {value} exceeds threshold {threshold}" + return True + + def process(self, value, threshold): + # Create image based on value + intensity = value / 10.0 + image = torch.ones([1, 512, 512, 3]) * intensity + return (image,) + + +class TestAsyncError(ComfyNodeABC): + """Test node that errors during async execution.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": (IO.ANY, {}), + "error_after": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 10.0}), + }, + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "error_execution" + CATEGORY = "_for_testing/async" + + async def error_execution(self, value, error_after): + await asyncio.sleep(error_after) + raise RuntimeError("Intentional async execution error for testing") + + +class TestAsyncValidationError(ComfyNodeABC): + """Test node with async validation that always fails.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("FLOAT", {"default": 5.0}), + "max_value": ("FLOAT", {"default": 10.0}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "process" + CATEGORY = "_for_testing/async" + + @classmethod + async def VALIDATE_INPUTS(cls, value, max_value): + await asyncio.sleep(0.05) + # Always fail validation for values > max_value + if value > max_value: + return f"Async validation failed: {value} > {max_value}" + return True + + def process(self, value, max_value): + # This won't be reached if validation fails + image = torch.ones([1, 512, 512, 3]) * (value / max_value) + return (image,) + + +class TestAsyncTimeout(ComfyNodeABC): + """Test node that simulates timeout scenarios.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": (IO.ANY, {}), + "timeout": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0}), + "operation_time": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 10.0}), + }, + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "timeout_execution" + CATEGORY = "_for_testing/async" + + async def timeout_execution(self, value, timeout, operation_time): + try: + # This will timeout if operation_time > timeout + await asyncio.wait_for(asyncio.sleep(operation_time), timeout=timeout) + return (value,) + except asyncio.TimeoutError: + raise RuntimeError(f"Operation timed out after {timeout} seconds") + + +class TestSyncError(ComfyNodeABC): + """Test node that errors synchronously (for mixed sync/async testing).""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": (IO.ANY, {}), + }, + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "sync_error" + CATEGORY = "_for_testing/async" + + def sync_error(self, value): + raise RuntimeError("Intentional sync execution error for testing") + + +class TestAsyncLazyCheck(ComfyNodeABC): + """Test node with async check_lazy_status.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": (IO.ANY, {"lazy": True}), + "input2": (IO.ANY, {"lazy": True}), + "condition": ("BOOLEAN", {"default": True}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "process" + CATEGORY = "_for_testing/async" + + async def check_lazy_status(self, condition, input1, input2): + # Simulate async checking (e.g., querying remote service) + await asyncio.sleep(0.05) + + needed = [] + if condition and input1 is None: + needed.append("input1") + if not condition and input2 is None: + needed.append("input2") + return needed + + def process(self, input1, input2, condition): + # Return a simple image + return (torch.ones([1, 512, 512, 3]),) + + +class TestDynamicAsyncGeneration(ComfyNodeABC): + """Test node that dynamically generates async nodes.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image1": ("IMAGE",), + "image2": ("IMAGE",), + "num_async_nodes": ("INT", {"default": 3, "min": 1, "max": 10}), + "sleep_duration": ("FLOAT", {"default": 0.2, "min": 0.1, "max": 1.0}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "generate_async_workflow" + CATEGORY = "_for_testing/async" + + def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration): + g = GraphBuilder() + + # Create multiple async sleep nodes + sleep_nodes = [] + for i in range(num_async_nodes): + image = image1 if i % 2 == 0 else image2 + sleep_node = g.node("TestSleep", value=image, seconds=sleep_duration) + sleep_nodes.append(sleep_node) + + # Average all results + if len(sleep_nodes) == 1: + final_node = sleep_nodes[0] + else: + avg_inputs = {"input1": sleep_nodes[0].out(0)} + for i, node in enumerate(sleep_nodes[1:], 2): + avg_inputs[f"input{i}"] = node.out(0) + final_node = g.node("TestVariadicAverage", **avg_inputs) + + return { + "result": (final_node.out(0),), + "expand": g.finalize(), + } + + +class TestAsyncResourceUser(ComfyNodeABC): + """Test node that uses resources during async execution.""" + + # Class-level resource tracking for testing + _active_resources: Dict[str, bool] = {} + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": (IO.ANY, {}), + "resource_id": ("STRING", {"default": "resource_0"}), + "duration": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0}), + }, + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "use_resource" + CATEGORY = "_for_testing/async" + + async def use_resource(self, value, resource_id, duration): + # Check if resource is already in use + if self._active_resources.get(resource_id, False): + raise RuntimeError(f"Resource {resource_id} is already in use!") + + # Mark resource as in use + self._active_resources[resource_id] = True + + try: + # Simulate resource usage + await asyncio.sleep(duration) + return (value,) + finally: + # Always clean up resource + self._active_resources[resource_id] = False + + +class TestAsyncBatchProcessing(ComfyNodeABC): + """Test async processing of batched inputs.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "images": ("IMAGE",), + "process_time_per_item": ("FLOAT", {"default": 0.1, "min": 0.01, "max": 1.0}), + }, + "hidden": { + "unique_id": "UNIQUE_ID", + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "process_batch" + CATEGORY = "_for_testing/async" + + async def process_batch(self, images, process_time_per_item, unique_id): + batch_size = images.shape[0] + pbar = ProgressBar(batch_size, node_id=unique_id) + + # Process each image in the batch + processed = [] + for i in range(batch_size): + # Simulate async processing + await asyncio.sleep(process_time_per_item) + + # Simple processing: invert the image + processed_image = 1.0 - images[i:i+1] + processed.append(processed_image) + + pbar.update(1) + + # Stack processed images + result = torch.cat(processed, dim=0) + return (result,) + + +class TestAsyncConcurrentLimit(ComfyNodeABC): + """Test concurrent execution limits for async nodes.""" + + _semaphore = asyncio.Semaphore(2) # Only allow 2 concurrent executions + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": (IO.ANY, {}), + "duration": ("FLOAT", {"default": 0.5, "min": 0.1, "max": 2.0}), + "node_id": ("INT", {"default": 0}), + }, + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "limited_execution" + CATEGORY = "_for_testing/async" + + async def limited_execution(self, value, duration, node_id): + async with self._semaphore: + # Node {node_id} acquired semaphore + await asyncio.sleep(duration) + # Node {node_id} releasing semaphore + return (value,) + + +# Add node mappings +ASYNC_TEST_NODE_CLASS_MAPPINGS = { + "TestAsyncValidation": TestAsyncValidation, + "TestAsyncError": TestAsyncError, + "TestAsyncValidationError": TestAsyncValidationError, + "TestAsyncTimeout": TestAsyncTimeout, + "TestSyncError": TestSyncError, + "TestAsyncLazyCheck": TestAsyncLazyCheck, + "TestDynamicAsyncGeneration": TestDynamicAsyncGeneration, + "TestAsyncResourceUser": TestAsyncResourceUser, + "TestAsyncBatchProcessing": TestAsyncBatchProcessing, + "TestAsyncConcurrentLimit": TestAsyncConcurrentLimit, +} + +ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS = { + "TestAsyncValidation": "Test Async Validation", + "TestAsyncError": "Test Async Error", + "TestAsyncValidationError": "Test Async Validation Error", + "TestAsyncTimeout": "Test Async Timeout", + "TestSyncError": "Test Sync Error", + "TestAsyncLazyCheck": "Test Async Lazy Check", + "TestDynamicAsyncGeneration": "Test Dynamic Async Generation", + "TestAsyncResourceUser": "Test Async Resource User", + "TestAsyncBatchProcessing": "Test Async Batch Processing", + "TestAsyncConcurrentLimit": "Test Async Concurrent Limit", +}