import subprocess
import sys
import time
import urllib.error

import numpy as np
import pytest
import torch
from pytest import fixture

from comfy_execution.graph_utils import GraphBuilder
from tests.inference.test_execution import ComfyClient


@pytest.mark.execution
class TestAsyncNodes:
    """Integration tests for async node execution against a live server.

    The class-scoped ``_server`` fixture is parametrized, so every test in
    this class runs once per server configuration: default caching,
    ``--cache-lru 0`` and ``--cache-lru 100``.
    """

    @fixture(scope="class", autouse=True, params=[
        (False, 0),
        (True, 0),
        (True, 100),
    ])
    def _server(self, args_pytest, request):
        """Start a server subprocess for this class and kill it afterwards.

        ``request.param`` is ``(use_lru, lru_size)``; when ``use_lru`` is true
        the server is started with ``--cache-lru <lru_size>``.
        """
        pargs = [
            # Use the interpreter running pytest so the server sees the same
            # environment (installed packages, venv) as the tests.
            sys.executable, 'main.py',
            '--output-directory', args_pytest["output_dir"],
            '--listen', args_pytest["listen"],
            '--port', str(args_pytest["port"]),
            '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
        ]
        use_lru, lru_size = request.param
        if use_lru:
            pargs += ['--cache-lru', str(lru_size)]
        # Running server with args: pargs
        p = subprocess.Popen(pargs)
        yield
        p.kill()
        # Reap the killed process so it has fully exited (and released its
        # port) before the next parametrization starts a new server.
        p.wait()
        torch.cuda.empty_cache()

    @fixture(scope="class", autouse=True)
    def shared_client(self, args_pytest, _server):
        """Connect one ComfyClient to the server, retrying while it boots.

        Fails the fixture (so every test in the class errors clearly) when no
        connection can be made after ``n_tries`` attempts, instead of yielding
        an unconnected client that would fail later with a confusing error.
        """
        client = ComfyClient()
        n_tries = 5
        for _ in range(n_tries):
            time.sleep(4)
            try:
                client.connect(listen=args_pytest["listen"], port=args_pytest["port"])
            except ConnectionRefusedError:
                # Server is not accepting connections yet; retry.
                continue
            break
        else:
            pytest.fail(f"Could not connect to server after {n_tries} attempts")
        yield client
        del client
        torch.cuda.empty_cache()

    @fixture
    def client(self, shared_client, request):
        """Per-test view of the shared client, tagged with the test's name."""
        shared_client.set_test_name(f"async_nodes[{request.node.name}]")
        yield shared_client

    @fixture
    def builder(self, request):
        """Fresh GraphBuilder whose node ids are prefixed with the test name."""
        yield GraphBuilder(prefix=request.node.name)

    # Happy Path Tests

    def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test that a basic async node executes correctly."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.1)
        output = g.node("SaveImage", images=sleep_node.out(0))

        result = client.run(g)

        # Verify execution completed
        assert result.did_run(sleep_node), "Async sleep node should have executed"
        assert result.did_run(output), "Output node should have executed"

        # Verify the image passed through correctly
        result_images = result.get_images(output)
        assert len(result_images) == 1, "Should have 1 image"
        pixels = np.array(result_images[0])
        assert pixels.min() == 0 and pixels.max() == 0, "Image should be black"

    def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test that multiple async nodes execute in parallel."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # Create multiple async sleep nodes with different durations
        sleep1 = g.node("TestSleep", value=image.out(0), seconds=0.3)
        sleep2 = g.node("TestSleep", value=image.out(0), seconds=0.4)
        sleep3 = g.node("TestSleep", value=image.out(0), seconds=0.5)

        # Add outputs for each
        _output1 = g.node("PreviewImage", images=sleep1.out(0))
        _output2 = g.node("PreviewImage", images=sleep2.out(0))
        _output3 = g.node("PreviewImage", images=sleep3.out(0))

        start_time = time.time()
        result = client.run(g)
        elapsed_time = time.time() - start_time

        # Should take ~0.5s (max duration) not 1.2s (sum of durations)
        assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s"

        # Verify all nodes executed
        assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3)

    def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes with proper dependency handling."""
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Chain of async operations
        sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
        sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)

        # Average depends on both async results
        average = g.node("TestVariadicAverage", input1=sleep1.out(0), input2=sleep2.out(0))
        output = g.node("SaveImage", images=average.out(0))

        result = client.run(g)

        # Verify execution order
        assert result.did_run(sleep1) and result.did_run(sleep2)
        assert result.did_run(average) and result.did_run(output)

        # Verify averaged result
        result_images = result.get_images(output)
        avg_value = np.array(result_images[0]).mean()
        assert abs(avg_value - 127.5) < 1, f"Average value {avg_value} should be ~127.5"

    def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder):
        """Test async VALIDATE_INPUTS function."""
        g = builder
        # Create a test node with async validation
        validation_node = g.node("TestAsyncValidation", value=5.0, threshold=10.0)
        g.node("SaveImage", images=validation_node.out(0))

        # Should pass validation
        result = client.run(g)
        assert result.did_run(validation_node)

        # Test validation failure
        validation_node.inputs['threshold'] = 3.0  # Will fail since value > threshold
        with pytest.raises(urllib.error.HTTPError):
            client.run(g)

    def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes with lazy evaluation."""
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1)

        # Create async nodes that will be evaluated lazily
        sleep1 = g.node("TestSleep", value=input1.out(0), seconds=0.3)
        sleep2 = g.node("TestSleep", value=input2.out(0), seconds=0.3)

        # Use lazy mix that only needs sleep1 (mask=0.0)
        lazy_mix = g.node("TestLazyMixImages", image1=sleep1.out(0), image2=sleep2.out(0), mask=mask.out(0))
        g.node("SaveImage", images=lazy_mix.out(0))

        start_time = time.time()
        result = client.run(g)
        elapsed_time = time.time() - start_time

        # Should only execute sleep1, not sleep2
        assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s"
        assert result.did_run(sleep1), "Sleep1 should have executed"
        assert not result.did_run(sleep2), "Sleep2 should have been skipped"

    def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder):
        """Test async check_lazy_status function."""
        g = builder
        # Create a node with async check_lazy_status
        lazy_node = g.node("TestAsyncLazyCheck", input1="value1", input2="value2", condition=True)
        g.node("SaveImage", images=lazy_node.out(0))

        result = client.run(g)
        assert result.did_run(lazy_node)

    # Error Handling Tests

    def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder):
        """Test that async execution errors are properly handled."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        # Create an async node that will error
        error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
        g.node("SaveImage", images=error_node.out(0))

        # pytest.raises only wraps the run itself, so a missing error cannot be
        # masked by catching the test's own failure assertion.
        with pytest.raises(Exception) as exc_info:
            client.run(g)
        e = exc_info.value
        assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
        assert e.args[0]['node_id'] == error_node.id, "Error should be from async error node"

    def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder):
        """Test async validation error handling."""
        g = builder
        # Node with async validation that will fail
        validation_node = g.node("TestAsyncValidationError", value=15.0, max_value=10.0)
        g.node("SaveImage", images=validation_node.out(0))

        with pytest.raises(urllib.error.HTTPError) as exc_info:
            client.run(g)
        # Verify it's a validation error
        assert exc_info.value.code == 400

    def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder):
        """Test handling of async operations that timeout."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        # Very long sleep that would timeout
        timeout_node = g.node("TestAsyncTimeout", value=image.out(0), timeout=0.5, operation_time=2.0)
        g.node("SaveImage", images=timeout_node.out(0))

        # Previously an `assert False, "...timeout..."` inside try/except was
        # itself caught and matched the 'timeout' check, so this test passed
        # even when no error was raised.
        with pytest.raises(Exception) as exc_info:
            client.run(g)
        e = exc_info.value
        assert 'timeout' in str(e).lower(), f"Expected timeout error, got: {e}"

    def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder):
        """Test that workflow can recover after async errors."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # First run with error; require that it actually fails so the
        # recovery check below is meaningful.
        error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
        g.node("SaveImage", images=error_node.out(0))
        with pytest.raises(Exception):
            client.run(g)

        # Second run should succeed
        g2 = GraphBuilder(prefix="recovery_test")
        image2 = g2.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        sleep_node = g2.node("TestSleep", value=image2.out(0), seconds=0.1)
        g2.node("SaveImage", images=sleep_node.out(0))

        result = client.run(g2)
        assert result.did_run(sleep_node), "Should be able to run after error"

    def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test handling when sync node errors while async node is executing."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # Async node that takes time
        sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.5)

        # Sync node that will error immediately
        error_node = g.node("TestSyncError", value=image.out(0))

        # Both feed into output
        g.node("PreviewImage", images=sleep_node.out(0))
        g.node("PreviewImage", images=error_node.out(0))

        with pytest.raises(Exception) as exc_info:
            client.run(g)
        # Verify the sync error was caught even though async was running
        assert 'prompt_id' in exc_info.value.args[0]

    # Edge Cases

    def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes with execution blockers."""
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Async sleep nodes
        sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
        sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)

        # Create list of images
        image_list = g.node("TestMakeListNode", value1=sleep1.out(0), value2=sleep2.out(0))

        # Create list of blocking conditions - [False, True] to block only the second item
        int1 = g.node("StubInt", value=1)
        int2 = g.node("StubInt", value=2)
        block_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0))

        # Compare each value against 2, so first is False (1 != 2) and second is True (2 == 2)
        compare = g.node("TestIntConditions", a=block_list.out(0), b=2, operation="==")

        # Block based on the comparison results
        blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)

        output = g.node("PreviewImage", images=blocker.out(0))

        result = client.run(g)
        images = result.get_images(output)
        assert len(images) == 1, "Should have blocked second image"

    def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
        """Test that async nodes are properly cached."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
        g.node("SaveImage", images=sleep_node.out(0))

        # First run
        result1 = client.run(g)
        assert result1.did_run(sleep_node), "Should run first time"

        # Second run - should be cached
        start_time = time.time()
        result2 = client.run(g)
        elapsed_time = time.time() - start_time

        assert not result2.did_run(sleep_node), "Should be cached"
        assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant"

    def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes within dynamically generated prompts."""
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Node that generates async nodes dynamically
        dynamic_async = g.node("TestDynamicAsyncGeneration",
                               image1=image1.out(0),
                               image2=image2.out(0),
                               num_async_nodes=3,
                               sleep_duration=0.2)
        g.node("SaveImage", images=dynamic_async.out(0))

        start_time = time.time()
        result = client.run(g)
        elapsed_time = time.time() - start_time

        # Should execute async nodes in parallel within dynamic prompt
        assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s"
        assert result.did_run(dynamic_async)

    def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder):
        """Test that async resources are properly cleaned up."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # Create multiple async nodes that use resources
        resource_nodes = []
        for i in range(5):
            node = g.node("TestAsyncResourceUser",
                          value=image.out(0),
                          resource_id=f"resource_{i}",
                          duration=0.1)
            resource_nodes.append(node)
            g.node("PreviewImage", images=node.out(0))

        result = client.run(g)

        # Verify all nodes executed
        for node in resource_nodes:
            assert result.did_run(node)

        # Run again to ensure resources were cleaned up
        result2 = client.run(g)
        # Should be cached but not error due to resource conflicts
        for node in resource_nodes:
            assert not result2.did_run(node), "Should be cached"

    def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder):
        """Test cancellation of async operations."""
        # This would require implementing cancellation in the client.
        # Skip explicitly so the missing coverage is reported rather than
        # showing up as a silent pass.
        pytest.skip("TODO: Implement when cancellation API is available")

    def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test workflows with both sync and async nodes."""
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)

        # Mix of sync and async operations
        # Sync: lazy mix images
        sync_op1 = g.node("TestLazyMixImages", image1=image1.out(0), image2=image2.out(0), mask=mask.out(0))
        # Async: sleep
        async_op1 = g.node("TestSleep", value=sync_op1.out(0), seconds=0.2)
        # Sync: custom validation
        sync_op2 = g.node("TestCustomValidation1", input1=async_op1.out(0), input2=0.5)
        # Async: sleep again
        async_op2 = g.node("TestSleep", value=sync_op2.out(0), seconds=0.2)

        output = g.node("SaveImage", images=async_op2.out(0))

        result = client.run(g)

        # Verify all nodes executed in correct order
        assert result.did_run(sync_op1)
        assert result.did_run(async_op1)
        assert result.did_run(sync_op2)
        assert result.did_run(async_op2)

        # A 50/50 black/white mix averages ~127.5. The expected ~63.75
        # (= 127.5 * 0.5) presumes TestCustomValidation1 scales its image
        # input by input2=0.5 — confirm against that node's implementation.
        result_images = result.get_images(output)
        avg_value = np.array(result_images[0]).mean()
        assert abs(avg_value - 63.75) < 5, f"Average value {avg_value} should be ~63.75"