mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-27 08:16:44 +00:00
Add additional tests for async error cases
Also fixes one bug that was found when an async function throws an error after being scheduled on a task.
This commit is contained in:
parent
92f9a10782
commit
0254d9cc11
14
execution.py
14
execution.py
@ -343,7 +343,17 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
input_data_all = None
|
input_data_all = None
|
||||||
try:
|
try:
|
||||||
if unique_id in pending_async_nodes:
|
if unique_id in pending_async_nodes:
|
||||||
results = [r.result() if isinstance(r, asyncio.Task) else r for r in pending_async_nodes[unique_id]]
|
results = []
|
||||||
|
for r in pending_async_nodes[unique_id]:
|
||||||
|
if isinstance(r, asyncio.Task):
|
||||||
|
try:
|
||||||
|
results.append(r.result())
|
||||||
|
except Exception as ex:
|
||||||
|
# An async task failed - propagate the exception up
|
||||||
|
del pending_async_nodes[unique_id]
|
||||||
|
raise ex
|
||||||
|
else:
|
||||||
|
results.append(r)
|
||||||
del pending_async_nodes[unique_id]
|
del pending_async_nodes[unique_id]
|
||||||
output_data, output_ui, has_subgraph = get_output_from_returns(results, class_def)
|
output_data, output_ui, has_subgraph = get_output_from_returns(results, class_def)
|
||||||
elif unique_id in pending_subgraph_results:
|
elif unique_id in pending_subgraph_results:
|
||||||
@ -418,7 +428,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
unblock = execution_list.add_external_block(unique_id)
|
unblock = execution_list.add_external_block(unique_id)
|
||||||
async def await_completion():
|
async def await_completion():
|
||||||
tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
|
tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
unblock()
|
unblock()
|
||||||
asyncio.create_task(await_completion())
|
asyncio.create_task(await_completion())
|
||||||
return (ExecutionResult.PENDING, None, None)
|
return (ExecutionResult.PENDING, None, None)
|
||||||
|
410
tests/inference/test_async_nodes.py
Normal file
410
tests/inference/test_async_nodes.py
Normal file
@ -0,0 +1,410 @@
|
|||||||
|
import pytest
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import urllib.error
|
||||||
|
import numpy as np
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
from pytest import fixture
|
||||||
|
from comfy_execution.graph_utils import GraphBuilder
|
||||||
|
from tests.inference.test_execution import ComfyClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.execution
|
||||||
|
class TestAsyncNodes:
|
||||||
|
@fixture(scope="class", autouse=True, params=[
|
||||||
|
(False, 0),
|
||||||
|
(True, 0),
|
||||||
|
(True, 100),
|
||||||
|
])
|
||||||
|
def _server(self, args_pytest, request):
|
||||||
|
pargs = [
|
||||||
|
'python','main.py',
|
||||||
|
'--output-directory', args_pytest["output_dir"],
|
||||||
|
'--listen', args_pytest["listen"],
|
||||||
|
'--port', str(args_pytest["port"]),
|
||||||
|
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
|
||||||
|
]
|
||||||
|
use_lru, lru_size = request.param
|
||||||
|
if use_lru:
|
||||||
|
pargs += ['--cache-lru', str(lru_size)]
|
||||||
|
# Running server with args: pargs
|
||||||
|
p = subprocess.Popen(pargs)
|
||||||
|
yield
|
||||||
|
p.kill()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
@fixture(scope="class", autouse=True)
|
||||||
|
def shared_client(self, args_pytest, _server):
|
||||||
|
client = ComfyClient()
|
||||||
|
n_tries = 5
|
||||||
|
for i in range(n_tries):
|
||||||
|
time.sleep(4)
|
||||||
|
try:
|
||||||
|
client.connect(listen=args_pytest["listen"], port=args_pytest["port"])
|
||||||
|
except ConnectionRefusedError:
|
||||||
|
# Retrying...
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
yield client
|
||||||
|
del client
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
@fixture
|
||||||
|
def client(self, shared_client, request):
|
||||||
|
shared_client.set_test_name(f"async_nodes[{request.node.name}]")
|
||||||
|
yield shared_client
|
||||||
|
|
||||||
|
@fixture
|
||||||
|
def builder(self, request):
|
||||||
|
yield GraphBuilder(prefix=request.node.name)
|
||||||
|
|
||||||
|
# Happy Path Tests
|
||||||
|
|
||||||
|
def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
"""Test that a basic async node executes correctly."""
|
||||||
|
g = builder
|
||||||
|
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.1)
|
||||||
|
output = g.node("SaveImage", images=sleep_node.out(0))
|
||||||
|
|
||||||
|
result = client.run(g)
|
||||||
|
|
||||||
|
# Verify execution completed
|
||||||
|
assert result.did_run(sleep_node), "Async sleep node should have executed"
|
||||||
|
assert result.did_run(output), "Output node should have executed"
|
||||||
|
|
||||||
|
# Verify the image passed through correctly
|
||||||
|
result_images = result.get_images(output)
|
||||||
|
assert len(result_images) == 1, "Should have 1 image"
|
||||||
|
assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black"
|
||||||
|
|
||||||
|
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
"""Test that multiple async nodes execute in parallel."""
|
||||||
|
g = builder
|
||||||
|
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
|
# Create multiple async sleep nodes with different durations
|
||||||
|
sleep1 = g.node("TestSleep", value=image.out(0), seconds=0.3)
|
||||||
|
sleep2 = g.node("TestSleep", value=image.out(0), seconds=0.4)
|
||||||
|
sleep3 = g.node("TestSleep", value=image.out(0), seconds=0.5)
|
||||||
|
|
||||||
|
# Add outputs for each
|
||||||
|
_output1 = g.node("PreviewImage", images=sleep1.out(0))
|
||||||
|
_output2 = g.node("PreviewImage", images=sleep2.out(0))
|
||||||
|
_output3 = g.node("PreviewImage", images=sleep3.out(0))
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
result = client.run(g)
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Should take ~0.5s (max duration) not 1.2s (sum of durations)
|
||||||
|
assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s"
|
||||||
|
|
||||||
|
# Verify all nodes executed
|
||||||
|
assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3)
|
||||||
|
|
||||||
|
def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
"""Test async nodes with proper dependency handling."""
|
||||||
|
g = builder
|
||||||
|
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
|
# Chain of async operations
|
||||||
|
sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
|
||||||
|
sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)
|
||||||
|
|
||||||
|
# Average depends on both async results
|
||||||
|
average = g.node("TestVariadicAverage", input1=sleep1.out(0), input2=sleep2.out(0))
|
||||||
|
output = g.node("SaveImage", images=average.out(0))
|
||||||
|
|
||||||
|
result = client.run(g)
|
||||||
|
|
||||||
|
# Verify execution order
|
||||||
|
assert result.did_run(sleep1) and result.did_run(sleep2)
|
||||||
|
assert result.did_run(average) and result.did_run(output)
|
||||||
|
|
||||||
|
# Verify averaged result
|
||||||
|
result_images = result.get_images(output)
|
||||||
|
avg_value = np.array(result_images[0]).mean()
|
||||||
|
assert abs(avg_value - 127.5) < 1, f"Average value {avg_value} should be ~127.5"
|
||||||
|
|
||||||
|
def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
"""Test async VALIDATE_INPUTS function."""
|
||||||
|
g = builder
|
||||||
|
# Create a test node with async validation
|
||||||
|
validation_node = g.node("TestAsyncValidation", value=5.0, threshold=10.0)
|
||||||
|
g.node("SaveImage", images=validation_node.out(0))
|
||||||
|
|
||||||
|
# Should pass validation
|
||||||
|
result = client.run(g)
|
||||||
|
assert result.did_run(validation_node)
|
||||||
|
|
||||||
|
# Test validation failure
|
||||||
|
validation_node.inputs['threshold'] = 3.0 # Will fail since value > threshold
|
||||||
|
with pytest.raises(urllib.error.HTTPError):
|
||||||
|
client.run(g)
|
||||||
|
|
||||||
|
def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
"""Test async nodes with lazy evaluation."""
|
||||||
|
g = builder
|
||||||
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
|
# Create async nodes that will be evaluated lazily
|
||||||
|
sleep1 = g.node("TestSleep", value=input1.out(0), seconds=0.3)
|
||||||
|
sleep2 = g.node("TestSleep", value=input2.out(0), seconds=0.3)
|
||||||
|
|
||||||
|
# Use lazy mix that only needs sleep1 (mask=0.0)
|
||||||
|
lazy_mix = g.node("TestLazyMixImages", image1=sleep1.out(0), image2=sleep2.out(0), mask=mask.out(0))
|
||||||
|
g.node("SaveImage", images=lazy_mix.out(0))
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
result = client.run(g)
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Should only execute sleep1, not sleep2
|
||||||
|
assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s"
|
||||||
|
assert result.did_run(sleep1), "Sleep1 should have executed"
|
||||||
|
assert not result.did_run(sleep2), "Sleep2 should have been skipped"
|
||||||
|
|
||||||
|
def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
"""Test async check_lazy_status function."""
|
||||||
|
g = builder
|
||||||
|
# Create a node with async check_lazy_status
|
||||||
|
lazy_node = g.node("TestAsyncLazyCheck",
|
||||||
|
input1="value1",
|
||||||
|
input2="value2",
|
||||||
|
condition=True)
|
||||||
|
g.node("SaveImage", images=lazy_node.out(0))
|
||||||
|
|
||||||
|
result = client.run(g)
|
||||||
|
assert result.did_run(lazy_node)
|
||||||
|
|
||||||
|
# Error Handling Tests
|
||||||
|
|
||||||
|
def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
"""Test that async execution errors are properly handled."""
|
||||||
|
g = builder
|
||||||
|
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
# Create an async node that will error
|
||||||
|
error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
|
||||||
|
g.node("SaveImage", images=error_node.out(0))
|
||||||
|
|
||||||
|
try:
|
||||||
|
client.run(g)
|
||||||
|
assert False, "Should have raised an error"
|
||||||
|
except Exception as e:
|
||||||
|
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
|
||||||
|
assert e.args[0]['node_id'] == error_node.id, "Error should be from async error node"
|
||||||
|
|
||||||
|
def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
"""Test async validation error handling."""
|
||||||
|
g = builder
|
||||||
|
# Node with async validation that will fail
|
||||||
|
validation_node = g.node("TestAsyncValidationError", value=15.0, max_value=10.0)
|
||||||
|
g.node("SaveImage", images=validation_node.out(0))
|
||||||
|
|
||||||
|
with pytest.raises(urllib.error.HTTPError) as exc_info:
|
||||||
|
client.run(g)
|
||||||
|
# Verify it's a validation error
|
||||||
|
assert exc_info.value.code == 400
|
||||||
|
|
||||||
|
def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
"""Test handling of async operations that timeout."""
|
||||||
|
g = builder
|
||||||
|
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
# Very long sleep that would timeout
|
||||||
|
timeout_node = g.node("TestAsyncTimeout", value=image.out(0), timeout=0.5, operation_time=2.0)
|
||||||
|
g.node("SaveImage", images=timeout_node.out(0))
|
||||||
|
|
||||||
|
try:
|
||||||
|
client.run(g)
|
||||||
|
assert False, "Should have raised a timeout error"
|
||||||
|
except Exception as e:
|
||||||
|
assert 'timeout' in str(e).lower(), f"Expected timeout error, got: {e}"
|
||||||
|
|
||||||
|
def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
"""Test that workflow can recover after async errors."""
|
||||||
|
g = builder
|
||||||
|
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
|
# First run with error
|
||||||
|
error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
|
||||||
|
g.node("SaveImage", images=error_node.out(0))
|
||||||
|
|
||||||
|
try:
|
||||||
|
client.run(g)
|
||||||
|
except Exception:
|
||||||
|
pass # Expected
|
||||||
|
|
||||||
|
# Second run should succeed
|
||||||
|
g2 = GraphBuilder(prefix="recovery_test")
|
||||||
|
image2 = g2.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
sleep_node = g2.node("TestSleep", value=image2.out(0), seconds=0.1)
|
||||||
|
g2.node("SaveImage", images=sleep_node.out(0))
|
||||||
|
|
||||||
|
result = client.run(g2)
|
||||||
|
assert result.did_run(sleep_node), "Should be able to run after error"
|
||||||
|
|
||||||
|
def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
"""Test handling when sync node errors while async node is executing."""
|
||||||
|
g = builder
|
||||||
|
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
|
# Async node that takes time
|
||||||
|
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.5)
|
||||||
|
|
||||||
|
# Sync node that will error immediately
|
||||||
|
error_node = g.node("TestSyncError", value=image.out(0))
|
||||||
|
|
||||||
|
# Both feed into output
|
||||||
|
g.node("PreviewImage", images=sleep_node.out(0))
|
||||||
|
g.node("PreviewImage", images=error_node.out(0))
|
||||||
|
|
||||||
|
try:
|
||||||
|
client.run(g)
|
||||||
|
assert False, "Should have raised an error"
|
||||||
|
except Exception as e:
|
||||||
|
# Verify the sync error was caught even though async was running
|
||||||
|
assert 'prompt_id' in e.args[0]
|
||||||
|
|
||||||
|
# Edge Cases
|
||||||
|
|
||||||
|
def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
"""Test async nodes with execution blockers."""
|
||||||
|
g = builder
|
||||||
|
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
|
# Async sleep nodes
|
||||||
|
sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
|
||||||
|
sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)
|
||||||
|
|
||||||
|
# Create list of images
|
||||||
|
image_list = g.node("TestMakeListNode", value1=sleep1.out(0), value2=sleep2.out(0))
|
||||||
|
|
||||||
|
# Create list of blocking conditions - [False, True] to block only the second item
|
||||||
|
int1 = g.node("StubInt", value=1)
|
||||||
|
int2 = g.node("StubInt", value=2)
|
||||||
|
block_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0))
|
||||||
|
|
||||||
|
# Compare each value against 2, so first is False (1 != 2) and second is True (2 == 2)
|
||||||
|
compare = g.node("TestIntConditions", a=block_list.out(0), b=2, operation="==")
|
||||||
|
|
||||||
|
# Block based on the comparison results
|
||||||
|
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
|
||||||
|
|
||||||
|
output = g.node("PreviewImage", images=blocker.out(0))
|
||||||
|
|
||||||
|
result = client.run(g)
|
||||||
|
images = result.get_images(output)
|
||||||
|
assert len(images) == 1, "Should have blocked second image"
|
||||||
|
|
||||||
|
def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
"""Test that async nodes are properly cached."""
|
||||||
|
g = builder
|
||||||
|
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
|
||||||
|
g.node("SaveImage", images=sleep_node.out(0))
|
||||||
|
|
||||||
|
# First run
|
||||||
|
result1 = client.run(g)
|
||||||
|
assert result1.did_run(sleep_node), "Should run first time"
|
||||||
|
|
||||||
|
# Second run - should be cached
|
||||||
|
start_time = time.time()
|
||||||
|
result2 = client.run(g)
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
|
assert not result2.did_run(sleep_node), "Should be cached"
|
||||||
|
assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant"
|
||||||
|
|
||||||
|
def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
"""Test async nodes within dynamically generated prompts."""
|
||||||
|
g = builder
|
||||||
|
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
|
# Node that generates async nodes dynamically
|
||||||
|
dynamic_async = g.node("TestDynamicAsyncGeneration",
|
||||||
|
image1=image1.out(0),
|
||||||
|
image2=image2.out(0),
|
||||||
|
num_async_nodes=3,
|
||||||
|
sleep_duration=0.2)
|
||||||
|
g.node("SaveImage", images=dynamic_async.out(0))
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
result = client.run(g)
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Should execute async nodes in parallel within dynamic prompt
|
||||||
|
assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s"
|
||||||
|
assert result.did_run(dynamic_async)
|
||||||
|
|
||||||
|
def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
"""Test that async resources are properly cleaned up."""
|
||||||
|
g = builder
|
||||||
|
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
|
# Create multiple async nodes that use resources
|
||||||
|
resource_nodes = []
|
||||||
|
for i in range(5):
|
||||||
|
node = g.node("TestAsyncResourceUser",
|
||||||
|
value=image.out(0),
|
||||||
|
resource_id=f"resource_{i}",
|
||||||
|
duration=0.1)
|
||||||
|
resource_nodes.append(node)
|
||||||
|
g.node("PreviewImage", images=node.out(0))
|
||||||
|
|
||||||
|
result = client.run(g)
|
||||||
|
|
||||||
|
# Verify all nodes executed
|
||||||
|
for node in resource_nodes:
|
||||||
|
assert result.did_run(node)
|
||||||
|
|
||||||
|
# Run again to ensure resources were cleaned up
|
||||||
|
result2 = client.run(g)
|
||||||
|
# Should be cached but not error due to resource conflicts
|
||||||
|
for node in resource_nodes:
|
||||||
|
assert not result2.did_run(node), "Should be cached"
|
||||||
|
|
||||||
|
def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
"""Test cancellation of async operations."""
|
||||||
|
# This would require implementing cancellation in the client
|
||||||
|
# For now, we'll test that long-running async operations can be interrupted
|
||||||
|
pass # TODO: Implement when cancellation API is available
|
||||||
|
|
||||||
|
def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
"""Test workflows with both sync and async nodes."""
|
||||||
|
g = builder
|
||||||
|
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
|
# Mix of sync and async operations
|
||||||
|
# Sync: lazy mix images
|
||||||
|
sync_op1 = g.node("TestLazyMixImages", image1=image1.out(0), image2=image2.out(0), mask=mask.out(0))
|
||||||
|
# Async: sleep
|
||||||
|
async_op1 = g.node("TestSleep", value=sync_op1.out(0), seconds=0.2)
|
||||||
|
# Sync: custom validation
|
||||||
|
sync_op2 = g.node("TestCustomValidation1", input1=async_op1.out(0), input2=0.5)
|
||||||
|
# Async: sleep again
|
||||||
|
async_op2 = g.node("TestSleep", value=sync_op2.out(0), seconds=0.2)
|
||||||
|
|
||||||
|
output = g.node("SaveImage", images=async_op2.out(0))
|
||||||
|
|
||||||
|
result = client.run(g)
|
||||||
|
|
||||||
|
# Verify all nodes executed in correct order
|
||||||
|
assert result.did_run(sync_op1)
|
||||||
|
assert result.did_run(async_op1)
|
||||||
|
assert result.did_run(sync_op2)
|
||||||
|
assert result.did_run(async_op2)
|
||||||
|
|
||||||
|
# Image should be a mix of black and white (gray)
|
||||||
|
result_images = result.get_images(output)
|
||||||
|
avg_value = np.array(result_images[0]).mean()
|
||||||
|
assert abs(avg_value - 63.75) < 5, f"Average value {avg_value} should be ~63.75"
|
@ -3,6 +3,7 @@ from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DI
|
|||||||
from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS
|
from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS
|
||||||
from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS
|
from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS
|
||||||
from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS
|
from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS
|
||||||
|
from .async_test_nodes import ASYNC_TEST_NODE_CLASS_MAPPINGS, ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS
|
||||||
|
|
||||||
# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
|
# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
|
||||||
# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)
|
# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)
|
||||||
@ -13,6 +14,7 @@ NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS)
|
|||||||
NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS)
|
NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS)
|
||||||
NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS)
|
NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS)
|
||||||
NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS)
|
NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS)
|
||||||
|
NODE_CLASS_MAPPINGS.update(ASYNC_TEST_NODE_CLASS_MAPPINGS)
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {}
|
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||||
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS)
|
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS)
|
||||||
@ -20,4 +22,5 @@ NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS)
|
|||||||
NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS)
|
NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS)
|
||||||
NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS)
|
NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS)
|
||||||
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS)
|
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS)
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS.update(ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS)
|
||||||
|
|
||||||
|
343
tests/inference/testing_nodes/testing-pack/async_test_nodes.py
Normal file
343
tests/inference/testing_nodes/testing-pack/async_test_nodes.py
Normal file
@ -0,0 +1,343 @@
|
|||||||
|
import torch
|
||||||
|
import asyncio
|
||||||
|
from typing import Dict
|
||||||
|
from comfy.utils import ProgressBar
|
||||||
|
from comfy_execution.graph_utils import GraphBuilder
|
||||||
|
from comfy.comfy_types.node_typing import ComfyNodeABC
|
||||||
|
from comfy.comfy_types import IO
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncValidation(ComfyNodeABC):
|
||||||
|
"""Test node with async VALIDATE_INPUTS."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"value": ("FLOAT", {"default": 5.0}),
|
||||||
|
"threshold": ("FLOAT", {"default": 10.0}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "process"
|
||||||
|
CATEGORY = "_for_testing/async"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def VALIDATE_INPUTS(cls, value, threshold):
|
||||||
|
# Simulate async validation (e.g., checking remote service)
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
if value > threshold:
|
||||||
|
return f"Value {value} exceeds threshold {threshold}"
|
||||||
|
return True
|
||||||
|
|
||||||
|
def process(self, value, threshold):
|
||||||
|
# Create image based on value
|
||||||
|
intensity = value / 10.0
|
||||||
|
image = torch.ones([1, 512, 512, 3]) * intensity
|
||||||
|
return (image,)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncError(ComfyNodeABC):
|
||||||
|
"""Test node that errors during async execution."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"value": (IO.ANY, {}),
|
||||||
|
"error_after": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 10.0}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.ANY,)
|
||||||
|
FUNCTION = "error_execution"
|
||||||
|
CATEGORY = "_for_testing/async"
|
||||||
|
|
||||||
|
async def error_execution(self, value, error_after):
|
||||||
|
await asyncio.sleep(error_after)
|
||||||
|
raise RuntimeError("Intentional async execution error for testing")
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncValidationError(ComfyNodeABC):
|
||||||
|
"""Test node with async validation that always fails."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"value": ("FLOAT", {"default": 5.0}),
|
||||||
|
"max_value": ("FLOAT", {"default": 10.0}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "process"
|
||||||
|
CATEGORY = "_for_testing/async"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def VALIDATE_INPUTS(cls, value, max_value):
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
# Always fail validation for values > max_value
|
||||||
|
if value > max_value:
|
||||||
|
return f"Async validation failed: {value} > {max_value}"
|
||||||
|
return True
|
||||||
|
|
||||||
|
def process(self, value, max_value):
|
||||||
|
# This won't be reached if validation fails
|
||||||
|
image = torch.ones([1, 512, 512, 3]) * (value / max_value)
|
||||||
|
return (image,)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncTimeout(ComfyNodeABC):
|
||||||
|
"""Test node that simulates timeout scenarios."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"value": (IO.ANY, {}),
|
||||||
|
"timeout": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0}),
|
||||||
|
"operation_time": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 10.0}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.ANY,)
|
||||||
|
FUNCTION = "timeout_execution"
|
||||||
|
CATEGORY = "_for_testing/async"
|
||||||
|
|
||||||
|
async def timeout_execution(self, value, timeout, operation_time):
|
||||||
|
try:
|
||||||
|
# This will timeout if operation_time > timeout
|
||||||
|
await asyncio.wait_for(asyncio.sleep(operation_time), timeout=timeout)
|
||||||
|
return (value,)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
raise RuntimeError(f"Operation timed out after {timeout} seconds")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncError(ComfyNodeABC):
|
||||||
|
"""Test node that errors synchronously (for mixed sync/async testing)."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"value": (IO.ANY, {}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.ANY,)
|
||||||
|
FUNCTION = "sync_error"
|
||||||
|
CATEGORY = "_for_testing/async"
|
||||||
|
|
||||||
|
def sync_error(self, value):
|
||||||
|
raise RuntimeError("Intentional sync execution error for testing")
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncLazyCheck(ComfyNodeABC):
|
||||||
|
"""Test node with async check_lazy_status."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"input1": (IO.ANY, {"lazy": True}),
|
||||||
|
"input2": (IO.ANY, {"lazy": True}),
|
||||||
|
"condition": ("BOOLEAN", {"default": True}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "process"
|
||||||
|
CATEGORY = "_for_testing/async"
|
||||||
|
|
||||||
|
async def check_lazy_status(self, condition, input1, input2):
|
||||||
|
# Simulate async checking (e.g., querying remote service)
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
needed = []
|
||||||
|
if condition and input1 is None:
|
||||||
|
needed.append("input1")
|
||||||
|
if not condition and input2 is None:
|
||||||
|
needed.append("input2")
|
||||||
|
return needed
|
||||||
|
|
||||||
|
def process(self, input1, input2, condition):
|
||||||
|
# Return a simple image
|
||||||
|
return (torch.ones([1, 512, 512, 3]),)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDynamicAsyncGeneration(ComfyNodeABC):
|
||||||
|
"""Test node that dynamically generates async nodes."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image1": ("IMAGE",),
|
||||||
|
"image2": ("IMAGE",),
|
||||||
|
"num_async_nodes": ("INT", {"default": 3, "min": 1, "max": 10}),
|
||||||
|
"sleep_duration": ("FLOAT", {"default": 0.2, "min": 0.1, "max": 1.0}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "generate_async_workflow"
|
||||||
|
CATEGORY = "_for_testing/async"
|
||||||
|
|
||||||
|
def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration):
|
||||||
|
g = GraphBuilder()
|
||||||
|
|
||||||
|
# Create multiple async sleep nodes
|
||||||
|
sleep_nodes = []
|
||||||
|
for i in range(num_async_nodes):
|
||||||
|
image = image1 if i % 2 == 0 else image2
|
||||||
|
sleep_node = g.node("TestSleep", value=image, seconds=sleep_duration)
|
||||||
|
sleep_nodes.append(sleep_node)
|
||||||
|
|
||||||
|
# Average all results
|
||||||
|
if len(sleep_nodes) == 1:
|
||||||
|
final_node = sleep_nodes[0]
|
||||||
|
else:
|
||||||
|
avg_inputs = {"input1": sleep_nodes[0].out(0)}
|
||||||
|
for i, node in enumerate(sleep_nodes[1:], 2):
|
||||||
|
avg_inputs[f"input{i}"] = node.out(0)
|
||||||
|
final_node = g.node("TestVariadicAverage", **avg_inputs)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"result": (final_node.out(0),),
|
||||||
|
"expand": g.finalize(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncResourceUser(ComfyNodeABC):
|
||||||
|
"""Test node that uses resources during async execution."""
|
||||||
|
|
||||||
|
# Class-level resource tracking for testing
|
||||||
|
_active_resources: Dict[str, bool] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"value": (IO.ANY, {}),
|
||||||
|
"resource_id": ("STRING", {"default": "resource_0"}),
|
||||||
|
"duration": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.ANY,)
|
||||||
|
FUNCTION = "use_resource"
|
||||||
|
CATEGORY = "_for_testing/async"
|
||||||
|
|
||||||
|
async def use_resource(self, value, resource_id, duration):
|
||||||
|
# Check if resource is already in use
|
||||||
|
if self._active_resources.get(resource_id, False):
|
||||||
|
raise RuntimeError(f"Resource {resource_id} is already in use!")
|
||||||
|
|
||||||
|
# Mark resource as in use
|
||||||
|
self._active_resources[resource_id] = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Simulate resource usage
|
||||||
|
await asyncio.sleep(duration)
|
||||||
|
return (value,)
|
||||||
|
finally:
|
||||||
|
# Always clean up resource
|
||||||
|
self._active_resources[resource_id] = False
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncBatchProcessing(ComfyNodeABC):
|
||||||
|
"""Test async processing of batched inputs."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"images": ("IMAGE",),
|
||||||
|
"process_time_per_item": ("FLOAT", {"default": 0.1, "min": 0.01, "max": 1.0}),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "process_batch"
|
||||||
|
CATEGORY = "_for_testing/async"
|
||||||
|
|
||||||
|
async def process_batch(self, images, process_time_per_item, unique_id):
|
||||||
|
batch_size = images.shape[0]
|
||||||
|
pbar = ProgressBar(batch_size, node_id=unique_id)
|
||||||
|
|
||||||
|
# Process each image in the batch
|
||||||
|
processed = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
# Simulate async processing
|
||||||
|
await asyncio.sleep(process_time_per_item)
|
||||||
|
|
||||||
|
# Simple processing: invert the image
|
||||||
|
processed_image = 1.0 - images[i:i+1]
|
||||||
|
processed.append(processed_image)
|
||||||
|
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
# Stack processed images
|
||||||
|
result = torch.cat(processed, dim=0)
|
||||||
|
return (result,)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncConcurrentLimit(ComfyNodeABC):
|
||||||
|
"""Test concurrent execution limits for async nodes."""
|
||||||
|
|
||||||
|
_semaphore = asyncio.Semaphore(2) # Only allow 2 concurrent executions
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"value": (IO.ANY, {}),
|
||||||
|
"duration": ("FLOAT", {"default": 0.5, "min": 0.1, "max": 2.0}),
|
||||||
|
"node_id": ("INT", {"default": 0}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.ANY,)
|
||||||
|
FUNCTION = "limited_execution"
|
||||||
|
CATEGORY = "_for_testing/async"
|
||||||
|
|
||||||
|
async def limited_execution(self, value, duration, node_id):
|
||||||
|
async with self._semaphore:
|
||||||
|
# Node {node_id} acquired semaphore
|
||||||
|
await asyncio.sleep(duration)
|
||||||
|
# Node {node_id} releasing semaphore
|
||||||
|
return (value,)
|
||||||
|
|
||||||
|
|
||||||
|
# Add node mappings
|
||||||
|
ASYNC_TEST_NODE_CLASS_MAPPINGS = {
|
||||||
|
"TestAsyncValidation": TestAsyncValidation,
|
||||||
|
"TestAsyncError": TestAsyncError,
|
||||||
|
"TestAsyncValidationError": TestAsyncValidationError,
|
||||||
|
"TestAsyncTimeout": TestAsyncTimeout,
|
||||||
|
"TestSyncError": TestSyncError,
|
||||||
|
"TestAsyncLazyCheck": TestAsyncLazyCheck,
|
||||||
|
"TestDynamicAsyncGeneration": TestDynamicAsyncGeneration,
|
||||||
|
"TestAsyncResourceUser": TestAsyncResourceUser,
|
||||||
|
"TestAsyncBatchProcessing": TestAsyncBatchProcessing,
|
||||||
|
"TestAsyncConcurrentLimit": TestAsyncConcurrentLimit,
|
||||||
|
}
|
||||||
|
|
||||||
|
ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"TestAsyncValidation": "Test Async Validation",
|
||||||
|
"TestAsyncError": "Test Async Error",
|
||||||
|
"TestAsyncValidationError": "Test Async Validation Error",
|
||||||
|
"TestAsyncTimeout": "Test Async Timeout",
|
||||||
|
"TestSyncError": "Test Sync Error",
|
||||||
|
"TestAsyncLazyCheck": "Test Async Lazy Check",
|
||||||
|
"TestDynamicAsyncGeneration": "Test Dynamic Async Generation",
|
||||||
|
"TestAsyncResourceUser": "Test Async Resource User",
|
||||||
|
"TestAsyncBatchProcessing": "Test Async Batch Processing",
|
||||||
|
"TestAsyncConcurrentLimit": "Test Async Concurrent Limit",
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user