Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-09-14 05:25:23 +00:00)
Fix showing progress from other sessions
Because `client_id` was missing from the `progress_state` message, the message was being broadcast to all connected sessions. In practice, this meant that anyone running a graph containing the same nodes could see another session's progress updates. This commit also adds a regression test to prevent a recurrence and moves the test files around to make CI easier to hook up.
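The shape of the fix, as a minimal sketch (the `sockets` registry, function name, and message layout below are illustrative stand-ins, not ComfyUI's actual internals): a session-scoped event such as `progress_state` needs to carry the owning session's id so the send routine can deliver it to that one socket instead of falling back to a broadcast.

# Hedged sketch of scoped vs. broadcast sends; all names here are illustrative.
from typing import Any, Dict, Optional

sockets: Dict[str, Any] = {}  # client_id -> connected websocket (stand-in registry)

def send_json(event: str, data: dict, sid: Optional[str] = None) -> None:
    # No session id: the event is broadcast to every connected socket (the old bug).
    # With a session id: only the owning client receives the update (the fix).
    targets = list(sockets.values()) if sid is None else [sockets[sid]]
    for ws in targets:
        ws.send_json({"type": event, "data": data})

# Before: send_json("progress_state", state)                 # broadcast to all sessions
# After:  send_json("progress_state", state, sid=client_id)  # scoped to one session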
4 tests/execution/extra_model_paths.yaml Normal file
@@ -0,0 +1,4 @@
# Config for testing nodes
testing:
    custom_nodes: testing_nodes
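For context, every server fixture in the test files below points the spawned ComfyUI process at this config, registering the `testing_nodes` directory as a custom-node path so the `Test*` and `Stub*` nodes used throughout these tests resolve. A minimal sketch of that launch, mirroring the fixtures (the output directory and port are placeholder values):

import subprocess

server = subprocess.Popen([
    'python', 'main.py',
    '--output-directory', '/tmp/comfyui-test-output',  # placeholder path
    '--listen', '127.0.0.1',
    '--port', '8188',
    '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
    '--cpu',
])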
423 tests/execution/test_async_nodes.py Normal file
@@ -0,0 +1,423 @@
import pytest
import time
import torch
import urllib.error
import numpy as np
import subprocess

from pytest import fixture
from comfy_execution.graph_utils import GraphBuilder
from tests.execution.test_execution import ComfyClient, run_warmup


@pytest.mark.execution
class TestAsyncNodes:
    @fixture(scope="class", autouse=True, params=[
        (False, 0),
        (True, 0),
        (True, 100),
    ])
    def _server(self, args_pytest, request):
        pargs = [
            'python','main.py',
            '--output-directory', args_pytest["output_dir"],
            '--listen', args_pytest["listen"],
            '--port', str(args_pytest["port"]),
            '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
            '--cpu',
        ]
        use_lru, lru_size = request.param
        if use_lru:
            pargs += ['--cache-lru', str(lru_size)]
        # Running server with args: pargs
        p = subprocess.Popen(pargs)
        yield
        p.kill()
        torch.cuda.empty_cache()

    @fixture(scope="class", autouse=True)
    def shared_client(self, args_pytest, _server):
        client = ComfyClient()
        n_tries = 5
        for i in range(n_tries):
            time.sleep(4)
            try:
                client.connect(listen=args_pytest["listen"], port=args_pytest["port"])
            except ConnectionRefusedError:
                # Retrying...
                pass
            else:
                break
        yield client
        del client
        torch.cuda.empty_cache()

    @fixture
    def client(self, shared_client, request):
        shared_client.set_test_name(f"async_nodes[{request.node.name}]")
        yield shared_client

    @fixture
    def builder(self, request):
        yield GraphBuilder(prefix=request.node.name)

    # Happy Path Tests

    def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test that a basic async node executes correctly."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.1)
        output = g.node("SaveImage", images=sleep_node.out(0))

        result = client.run(g)

        # Verify execution completed
        assert result.did_run(sleep_node), "Async sleep node should have executed"
        assert result.did_run(output), "Output node should have executed"

        # Verify the image passed through correctly
        result_images = result.get_images(output)
        assert len(result_images) == 1, "Should have 1 image"
        assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black"

    def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test that multiple async nodes execute in parallel."""
        # Warmup execution to ensure server is fully initialized
        run_warmup(client)

        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # Create multiple async sleep nodes with different durations
        sleep1 = g.node("TestSleep", value=image.out(0), seconds=0.3)
        sleep2 = g.node("TestSleep", value=image.out(0), seconds=0.4)
        sleep3 = g.node("TestSleep", value=image.out(0), seconds=0.5)

        # Add outputs for each
        _output1 = g.node("PreviewImage", images=sleep1.out(0))
        _output2 = g.node("PreviewImage", images=sleep2.out(0))
        _output3 = g.node("PreviewImage", images=sleep3.out(0))

        start_time = time.time()
        result = client.run(g)
        elapsed_time = time.time() - start_time

        # Should take ~0.5s (max duration) not 1.2s (sum of durations)
        assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s"

        # Verify all nodes executed
        assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3)

    def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes with proper dependency handling."""
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Chain of async operations
        sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
        sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)

        # Average depends on both async results
        average = g.node("TestVariadicAverage", input1=sleep1.out(0), input2=sleep2.out(0))
        output = g.node("SaveImage", images=average.out(0))

        result = client.run(g)

        # Verify execution order
        assert result.did_run(sleep1) and result.did_run(sleep2)
        assert result.did_run(average) and result.did_run(output)

        # Verify averaged result
        result_images = result.get_images(output)
        avg_value = np.array(result_images[0]).mean()
        assert abs(avg_value - 127.5) < 1, f"Average value {avg_value} should be ~127.5"

    def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder):
        """Test async VALIDATE_INPUTS function."""
        g = builder
        # Create a test node with async validation
        validation_node = g.node("TestAsyncValidation", value=5.0, threshold=10.0)
        g.node("SaveImage", images=validation_node.out(0))

        # Should pass validation
        result = client.run(g)
        assert result.did_run(validation_node)

        # Test validation failure
        validation_node.inputs['threshold'] = 3.0  # Will fail since value > threshold
        with pytest.raises(urllib.error.HTTPError):
            client.run(g)

    def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes with lazy evaluation."""
        # Warmup execution to ensure server is fully initialized
        run_warmup(client, prefix="warmup_lazy")

        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1)

        # Create async nodes that will be evaluated lazily
        sleep1 = g.node("TestSleep", value=input1.out(0), seconds=0.3)
        sleep2 = g.node("TestSleep", value=input2.out(0), seconds=0.3)

        # Use lazy mix that only needs sleep1 (mask=0.0)
        lazy_mix = g.node("TestLazyMixImages", image1=sleep1.out(0), image2=sleep2.out(0), mask=mask.out(0))
        g.node("SaveImage", images=lazy_mix.out(0))

        start_time = time.time()
        result = client.run(g)
        elapsed_time = time.time() - start_time

        # Should only execute sleep1, not sleep2
        assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s"
        assert result.did_run(sleep1), "Sleep1 should have executed"
        assert not result.did_run(sleep2), "Sleep2 should have been skipped"

    def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder):
        """Test async check_lazy_status function."""
        g = builder
        # Create a node with async check_lazy_status
        lazy_node = g.node("TestAsyncLazyCheck",
                           input1="value1",
                           input2="value2",
                           condition=True)
        g.node("SaveImage", images=lazy_node.out(0))

        result = client.run(g)
        assert result.did_run(lazy_node)

    # Error Handling Tests

    def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder):
        """Test that async execution errors are properly handled."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        # Create an async node that will error
        error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
        g.node("SaveImage", images=error_node.out(0))

        try:
            client.run(g)
            assert False, "Should have raised an error"
        except Exception as e:
            assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
            assert e.args[0]['node_id'] == error_node.id, "Error should be from async error node"

    def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder):
        """Test async validation error handling."""
        g = builder
        # Node with async validation that will fail
        validation_node = g.node("TestAsyncValidationError", value=15.0, max_value=10.0)
        g.node("SaveImage", images=validation_node.out(0))

        with pytest.raises(urllib.error.HTTPError) as exc_info:
            client.run(g)
        # Verify it's a validation error
        assert exc_info.value.code == 400

    def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder):
        """Test handling of async operations that timeout."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        # Very long sleep that would timeout
        timeout_node = g.node("TestAsyncTimeout", value=image.out(0), timeout=0.5, operation_time=2.0)
        g.node("SaveImage", images=timeout_node.out(0))

        try:
            client.run(g)
            assert False, "Should have raised a timeout error"
        except Exception as e:
            assert 'timeout' in str(e).lower(), f"Expected timeout error, got: {e}"

    def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder):
        """Test that workflow can recover after async errors."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # First run with error
        error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
        g.node("SaveImage", images=error_node.out(0))

        try:
            client.run(g)
        except Exception:
            pass  # Expected

        # Second run should succeed
        g2 = GraphBuilder(prefix="recovery_test")
        image2 = g2.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        sleep_node = g2.node("TestSleep", value=image2.out(0), seconds=0.1)
        g2.node("SaveImage", images=sleep_node.out(0))

        result = client.run(g2)
        assert result.did_run(sleep_node), "Should be able to run after error"

    def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test handling when sync node errors while async node is executing."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # Async node that takes time
        sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.5)

        # Sync node that will error immediately
        error_node = g.node("TestSyncError", value=image.out(0))

        # Both feed into output
        g.node("PreviewImage", images=sleep_node.out(0))
        g.node("PreviewImage", images=error_node.out(0))

        try:
            client.run(g)
            assert False, "Should have raised an error"
        except Exception as e:
            # Verify the sync error was caught even though async was running
            assert 'prompt_id' in e.args[0]

    # Edge Cases

    def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes with execution blockers."""
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Async sleep nodes
        sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
        sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)

        # Create list of images
        image_list = g.node("TestMakeListNode", value1=sleep1.out(0), value2=sleep2.out(0))

        # Create list of blocking conditions - [False, True] to block only the second item
        int1 = g.node("StubInt", value=1)
        int2 = g.node("StubInt", value=2)
        block_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0))

        # Compare each value against 2, so first is False (1 != 2) and second is True (2 == 2)
        compare = g.node("TestIntConditions", a=block_list.out(0), b=2, operation="==")

        # Block based on the comparison results
        blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)

        output = g.node("PreviewImage", images=blocker.out(0))

        result = client.run(g)
        images = result.get_images(output)
        assert len(images) == 1, "Should have blocked second image"

    def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
        """Test that async nodes are properly cached."""
        # Warmup execution to ensure server is fully initialized
        run_warmup(client, prefix="warmup_cache")

        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
        g.node("SaveImage", images=sleep_node.out(0))

        # First run
        result1 = client.run(g)
        assert result1.did_run(sleep_node), "Should run first time"

        # Second run - should be cached
        start_time = time.time()
        result2 = client.run(g)
        elapsed_time = time.time() - start_time

        assert not result2.did_run(sleep_node), "Should be cached"
        assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant"

    def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes within dynamically generated prompts."""
        # Warmup execution to ensure server is fully initialized
        run_warmup(client, prefix="warmup_dynamic")

        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Node that generates async nodes dynamically
        dynamic_async = g.node("TestDynamicAsyncGeneration",
                               image1=image1.out(0),
                               image2=image2.out(0),
                               num_async_nodes=3,
                               sleep_duration=0.2)
        g.node("SaveImage", images=dynamic_async.out(0))

        start_time = time.time()
        result = client.run(g)
        elapsed_time = time.time() - start_time

        # Should execute async nodes in parallel within dynamic prompt
        assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s"
        assert result.did_run(dynamic_async)

    def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder):
        """Test that async resources are properly cleaned up."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # Create multiple async nodes that use resources
        resource_nodes = []
        for i in range(5):
            node = g.node("TestAsyncResourceUser",
                          value=image.out(0),
                          resource_id=f"resource_{i}",
                          duration=0.1)
            resource_nodes.append(node)
            g.node("PreviewImage", images=node.out(0))

        result = client.run(g)

        # Verify all nodes executed
        for node in resource_nodes:
            assert result.did_run(node)

        # Run again to ensure resources were cleaned up
        result2 = client.run(g)
        # Should be cached but not error due to resource conflicts
        for node in resource_nodes:
            assert not result2.did_run(node), "Should be cached"

    def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder):
        """Test cancellation of async operations."""
        # This would require implementing cancellation in the client
        # For now, we'll test that long-running async operations can be interrupted
        pass  # TODO: Implement when cancellation API is available

    def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test workflows with both sync and async nodes."""
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)

        # Mix of sync and async operations
        # Sync: lazy mix images
        sync_op1 = g.node("TestLazyMixImages", image1=image1.out(0), image2=image2.out(0), mask=mask.out(0))
        # Async: sleep
        async_op1 = g.node("TestSleep", value=sync_op1.out(0), seconds=0.2)
        # Sync: custom validation
        sync_op2 = g.node("TestCustomValidation1", input1=async_op1.out(0), input2=0.5)
        # Async: sleep again
        async_op2 = g.node("TestSleep", value=sync_op2.out(0), seconds=0.2)

        output = g.node("SaveImage", images=async_op2.out(0))

        result = client.run(g)

        # Verify all nodes executed in correct order
        assert result.did_run(sync_op1)
        assert result.did_run(async_op1)
        assert result.did_run(sync_op2)
        assert result.did_run(async_op2)

        # Image should be a mix of black and white (gray)
        result_images = result.get_images(output)
        avg_value = np.array(result_images[0]).mean()
        assert abs(avg_value - 63.75) < 5, f"Average value {avg_value} should be ~63.75"
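The timing assertions above, where elapsed wall-clock time tracks the longest sleep rather than the sum, hinge on the executor awaiting the sleep nodes concurrently. A standalone sketch of that property in plain asyncio, independent of ComfyUI's executor:

import asyncio
import time

async def sleep_node(seconds: float) -> float:
    # Stand-in for TestSleep: yields control to the event loop while "working".
    await asyncio.sleep(seconds)
    return seconds

async def main() -> None:
    start = time.time()
    # Three concurrent sleeps finish in roughly max(0.3, 0.4, 0.5), not the 1.2s sum.
    await asyncio.gather(sleep_node(0.3), sleep_node(0.4), sleep_node(0.5))
    elapsed = time.time() - start
    assert elapsed < 0.8, f"took {elapsed:.2f}s; concurrent sleeps should track the max (~0.5s)"

asyncio.run(main())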
761 tests/execution/test_execution.py Normal file
@@ -0,0 +1,761 @@
from io import BytesIO
import numpy
from PIL import Image
import pytest
from pytest import fixture
import time
import torch
from typing import Union, Dict
import json
import subprocess
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import urllib.request
import urllib.parse
import urllib.error
from comfy_execution.graph_utils import GraphBuilder, Node

def run_warmup(client, prefix="warmup"):
    """Run a simple workflow to warm up the server."""
    warmup_g = GraphBuilder(prefix=prefix)
    warmup_image = warmup_g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1)
    warmup_g.node("PreviewImage", images=warmup_image.out(0))
    client.run(warmup_g)

class RunResult:
    def __init__(self, prompt_id: str):
        self.outputs: Dict[str,Dict] = {}
        self.runs: Dict[str,bool] = {}
        self.cached: Dict[str,bool] = {}
        self.prompt_id: str = prompt_id

    def get_output(self, node: Node):
        return self.outputs.get(node.id, None)

    def did_run(self, node: Node):
        return self.runs.get(node.id, False)

    def was_cached(self, node: Node):
        return self.cached.get(node.id, False)

    def was_executed(self, node: Node):
        """Returns True if node was either run or cached"""
        return self.did_run(node) or self.was_cached(node)

    def get_images(self, node: Node):
        output = self.get_output(node)
        if output is None:
            return []
        return output.get('image_objects', [])

    def get_prompt_id(self):
        return self.prompt_id

class ComfyClient:
    def __init__(self):
        self.test_name = ""

    def connect(self,
                listen:str = '127.0.0.1',
                port:Union[str,int] = 8188,
                client_id: str = str(uuid.uuid4())
                ):
        self.client_id = client_id
        self.server_address = f"{listen}:{port}"
        ws = websocket.WebSocket()
        ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
        self.ws = ws

    def queue_prompt(self, prompt, partial_execution_targets=None):
        p = {"prompt": prompt, "client_id": self.client_id}
        if partial_execution_targets is not None:
            p["partial_execution_targets"] = partial_execution_targets
        data = json.dumps(p).encode('utf-8')
        req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
        return json.loads(urllib.request.urlopen(req).read())

    def get_image(self, filename, subfolder, folder_type):
        data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
        url_values = urllib.parse.urlencode(data)
        with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response:
            return response.read()

    def get_history(self, prompt_id):
        with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response:
            return json.loads(response.read())

    def set_test_name(self, name):
        self.test_name = name

    def run(self, graph, partial_execution_targets=None):
        prompt = graph.finalize()
        for node in graph.nodes.values():
            if node.class_type == 'SaveImage':
                node.inputs['filename_prefix'] = self.test_name

        prompt_id = self.queue_prompt(prompt, partial_execution_targets)['prompt_id']
        result = RunResult(prompt_id)
        while True:
            out = self.ws.recv()
            if isinstance(out, str):
                message = json.loads(out)
                if message['type'] == 'executing':
                    data = message['data']
                    if data['prompt_id'] != prompt_id:
                        continue
                    if data['node'] is None:
                        break
                    result.runs[data['node']] = True
                elif message['type'] == 'execution_error':
                    raise Exception(message['data'])
                elif message['type'] == 'execution_cached':
                    if message['data']['prompt_id'] == prompt_id:
                        cached_nodes = message['data'].get('nodes', [])
                        for node_id in cached_nodes:
                            result.cached[node_id] = True

        history = self.get_history(prompt_id)[prompt_id]
        for node_id in history['outputs']:
            node_output = history['outputs'][node_id]
            result.outputs[node_id] = node_output
            images_output = []
            if 'images' in node_output:
                for image in node_output['images']:
                    image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
                    image_obj = Image.open(BytesIO(image_data))
                    images_output.append(image_obj)
                node_output['image_objects'] = images_output

        return result

#
# Loop through these variables
#
@pytest.mark.execution
class TestExecution:
    #
    # Initialize server and client
    #
    @fixture(scope="class", autouse=True, params=[
        # (use_lru, lru_size)
        (False, 0),
        (True, 0),
        (True, 100),
    ])
    def _server(self, args_pytest, request):
        # Start server
        pargs = [
            'python','main.py',
            '--output-directory', args_pytest["output_dir"],
            '--listen', args_pytest["listen"],
            '--port', str(args_pytest["port"]),
            '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
            '--cpu',
        ]
        use_lru, lru_size = request.param
        if use_lru:
            pargs += ['--cache-lru', str(lru_size)]
        print("Running server with args:", pargs) # noqa: T201
        p = subprocess.Popen(pargs)
        yield
        p.kill()
        torch.cuda.empty_cache()

    def start_client(self, listen:str, port:int):
        # Start client
        comfy_client = ComfyClient()
        # Connect to server (with retries)
        n_tries = 5
        for i in range(n_tries):
            time.sleep(4)
            try:
                comfy_client.connect(listen=listen, port=port)
            except ConnectionRefusedError as e:
                print(e) # noqa: T201
                print(f"({i+1}/{n_tries}) Retrying...") # noqa: T201
            else:
                break
        return comfy_client

    @fixture(scope="class", autouse=True)
    def shared_client(self, args_pytest, _server):
        client = self.start_client(args_pytest["listen"], args_pytest["port"])
        yield client
        del client
        torch.cuda.empty_cache()

    @fixture
    def client(self, shared_client, request):
        shared_client.set_test_name(f"execution[{request.node.name}]")
        yield shared_client

    @fixture
    def builder(self, request):
        yield GraphBuilder(prefix=request.node.name)

    def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1)

        lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
        output = g.node("SaveImage", images=lazy_mix.out(0))
        result = client.run(g)

        result_image = result.get_images(output)[0]
        assert numpy.array(result_image).any() == 0, "Image should be black"
        assert result.did_run(input1)
        assert not result.did_run(input2)
        assert result.did_run(mask)
        assert result.did_run(lazy_mix)

    def test_full_cache(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)

        lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
        g.node("SaveImage", images=lazy_mix.out(0))

        client.run(g)
        result2 = client.run(g)
        for node_id, node in g.nodes.items():
            assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"

    def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)

        lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
        g.node("SaveImage", images=lazy_mix.out(0))

        client.run(g)
        mask.inputs['value'] = 0.4
        result2 = client.run(g)
        assert not result2.did_run(input1), "Input1 should have been cached"
        assert not result2.did_run(input2), "Input2 should have been cached"

    def test_error(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        # Different size of the two images
        input2 = g.node("StubImage", content="NOISE", height=256, width=256, batch_size=1)
        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)

        lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
        g.node("SaveImage", images=lazy_mix.out(0))

        try:
            client.run(g)
            assert False, "Should have raised an error"
        except Exception as e:
            assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"

    @pytest.mark.parametrize("test_value, expect_error", [
        (5, True),
        ("foo", True),
        (5.0, False),
    ])
    def test_validation_error_literal(self, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
        g = builder
        validation1 = g.node("TestCustomValidation1", input1=test_value, input2=3.0)
        g.node("SaveImage", images=validation1.out(0))

        if expect_error:
            with pytest.raises(urllib.error.HTTPError):
                client.run(g)
        else:
            client.run(g)

    @pytest.mark.parametrize("test_type, test_value", [
        ("StubInt", 5),
        ("StubMask", 5.0)
    ])
    def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder):
        g = builder
        stub = g.node(test_type, value=test_value)
        validation1 = g.node("TestCustomValidation1", input1=stub.out(0), input2=3.0)
        g.node("SaveImage", images=validation1.out(0))

        with pytest.raises(urllib.error.HTTPError):
            client.run(g)

    @pytest.mark.parametrize("test_type, test_value, expect_error", [
        ("StubInt", 5, True),
        ("StubFloat", 5.0, False)
    ])
    def test_validation_error_edge2(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
        g = builder
        stub = g.node(test_type, value=test_value)
        validation2 = g.node("TestCustomValidation2", input1=stub.out(0), input2=3.0)
        g.node("SaveImage", images=validation2.out(0))

        if expect_error:
            with pytest.raises(urllib.error.HTTPError):
                client.run(g)
        else:
            client.run(g)

    @pytest.mark.parametrize("test_type, test_value, expect_error", [
        ("StubInt", 5, True),
        ("StubFloat", 5.0, False)
    ])
    def test_validation_error_edge3(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
        g = builder
        stub = g.node(test_type, value=test_value)
        validation3 = g.node("TestCustomValidation3", input1=stub.out(0), input2=3.0)
        g.node("SaveImage", images=validation3.out(0))

        if expect_error:
            with pytest.raises(urllib.error.HTTPError):
                client.run(g)
        else:
            client.run(g)

    @pytest.mark.parametrize("test_type, test_value, expect_error", [
        ("StubInt", 5, True),
        ("StubFloat", 5.0, False)
    ])
    def test_validation_error_edge4(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
        g = builder
        stub = g.node(test_type, value=test_value)
        validation4 = g.node("TestCustomValidation4", input1=stub.out(0), input2=3.0)
        g.node("SaveImage", images=validation4.out(0))

        if expect_error:
            with pytest.raises(urllib.error.HTTPError):
                client.run(g)
        else:
            client.run(g)

    @pytest.mark.parametrize("test_value1, test_value2, expect_error", [
        (0.0, 0.5, False),
        (0.0, 5.0, False),
        (0.0, 7.0, True)
    ])
    def test_validation_error_kwargs(self, test_value1, test_value2, expect_error, client: ComfyClient, builder: GraphBuilder):
        g = builder
        validation5 = g.node("TestCustomValidation5", input1=test_value1, input2=test_value2)
        g.node("SaveImage", images=validation5.out(0))

        if expect_error:
            with pytest.raises(urllib.error.HTTPError):
                client.run(g)
        else:
            client.run(g)

    def test_cycle_error(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)

        lazy_mix1 = g.node("TestLazyMixImages", image1=input1.out(0), mask=mask.out(0))
        lazy_mix2 = g.node("TestLazyMixImages", image1=lazy_mix1.out(0), image2=input2.out(0), mask=mask.out(0))
        lazy_mix1.set_input('image2', lazy_mix2.out(0))  # close the loop so the graph contains a cycle
        g.node("SaveImage", images=lazy_mix2.out(0))

        # When the cycle exists on initial submission, it should raise a validation error
        with pytest.raises(urllib.error.HTTPError):
            client.run(g)

    def test_dynamic_cycle_error(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        generator = g.node("TestDynamicDependencyCycle", input1=input1.out(0), input2=input2.out(0))
        g.node("SaveImage", images=generator.out(0))

        # When the cycle is in a graph that is generated dynamically, it should raise a runtime error
        try:
            client.run(g)
            assert False, "Should have raised an error"
        except Exception as e:
            assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"
            assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node"

    def test_missing_node_error(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
        input3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
        mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
        mix2 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input3.out(0), mask=mask.out(0))
        # We have multiple outputs. The first is invalid, but the second is valid
        g.node("SaveImage", images=mix1.out(0))
        g.node("SaveImage", images=mix2.out(0))
        g.remove_node("removeme")

        client.run(g)

        # Add back in the missing node to make sure the error doesn't break the server
        input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
        client.run(g)

    def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        # Creating the nodes in this specific order previously caused a bug
        save = g.node("SaveImage")
        is_changed = g.node("TestCustomIsChanged", should_change=False)
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        save.set_input('images', is_changed.out(0))
        is_changed.set_input('image', input1.out(0))

        result1 = client.run(g)
        result2 = client.run(g)
        is_changed.set_input('should_change', True)
        result3 = client.run(g)
        result4 = client.run(g)
        assert result1.did_run(is_changed), "is_changed should have been run"
        assert not result2.did_run(is_changed), "is_changed should have been cached"
        assert result3.did_run(is_changed), "is_changed should have been re-run"
        assert result4.did_run(is_changed), "is_changed should not have been cached"

    def test_undeclared_inputs(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        input3 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input4 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        average = g.node("TestVariadicAverage", input1=input1.out(0), input2=input2.out(0), input3=input3.out(0), input4=input4.out(0))
        output = g.node("SaveImage", images=average.out(0))

        result = client.run(g)
        result_image = result.get_images(output)[0]
        expected = 255 // 4
        assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"

    def test_for_loop(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        iterations = 4
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        is_changed = g.node("TestCustomIsChanged", should_change=True, image=input2.out(0))
        for_open = g.node("TestForLoopOpen", remaining=iterations, initial_value1=is_changed.out(0))
        average = g.node("TestVariadicAverage", input1=input1.out(0), input2=for_open.out(2))
        for_close = g.node("TestForLoopClose", flow_control=for_open.out(0), initial_value1=average.out(0))
        output = g.node("SaveImage", images=for_close.out(0))

        for iterations in range(1, 5):
            for_open.set_input('remaining', iterations)
            result = client.run(g)
            result_image = result.get_images(output)[0]
            expected = 255 // (2 ** iterations)
            assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"
            assert result.did_run(is_changed)

    def test_mixed_expansion_returns(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        val_list = g.node("TestMakeListNode", value1=0.1, value2=0.2, value3=0.3)
        mixed = g.node("TestMixedExpansionReturns", input1=val_list.out(0))
        output_dynamic = g.node("SaveImage", images=mixed.out(0))
        output_literal = g.node("SaveImage", images=mixed.out(1))

        result = client.run(g)
        images_dynamic = result.get_images(output_dynamic)
        assert len(images_dynamic) == 3, "Should have 3 images"
        assert numpy.array(images_dynamic[0]).min() == 25 and numpy.array(images_dynamic[0]).max() == 25, "First image should be 0.1"
        assert numpy.array(images_dynamic[1]).min() == 51 and numpy.array(images_dynamic[1]).max() == 51, "Second image should be 0.2"
        assert numpy.array(images_dynamic[2]).min() == 76 and numpy.array(images_dynamic[2]).max() == 76, "Third image should be 0.3"

        images_literal = result.get_images(output_literal)
        assert len(images_literal) == 3, "Should have 3 images"
        for i in range(3):
            assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white"

    def test_mixed_lazy_results(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        val_list = g.node("TestMakeListNode", value1=0.0, value2=0.5, value3=1.0)
        mask = g.node("StubMask", value=val_list.out(0), height=512, width=512, batch_size=1)
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
        rebatch = g.node("RebatchImages", images=mix.out(0), batch_size=3)
        output = g.node("SaveImage", images=rebatch.out(0))

        result = client.run(g)
        images = result.get_images(output)
        assert len(images) == 3, "Should have 3 images"
        assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be 0.0"
        assert numpy.array(images[1]).min() == 127 and numpy.array(images[1]).max() == 127, "Second image should be 0.5"
        assert numpy.array(images[2]).min() == 255 and numpy.array(images[2]).max() == 255, "Third image should be 1.0"

    def test_output_reuse(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        output1 = g.node("SaveImage", images=input1.out(0))
        output2 = g.node("SaveImage", images=input1.out(0))

        result = client.run(g)
        images1 = result.get_images(output1)
        images2 = result.get_images(output2)
        assert len(images1) == 1, "Should have 1 image"
        assert len(images2) == 1, "Should have 1 image"


    # This tests that only constant outputs are used in the call to `IS_CHANGED`
    def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
        test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)

        output = g.node("PreviewImage", images=test_node.out(0))

        result = client.run(g)
        images = result.get_images(output)
        assert len(images) == 1, "Should have 1 image"
        assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"

        result = client.run(g)
        images = result.get_images(output)
        assert len(images) == 1, "Should have 1 image"
        assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
        assert not result.did_run(test_node), "The execution should have been cached"

    def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
        # Warmup execution to ensure server is fully initialized
        run_warmup(client)

        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # Create sleep nodes for each duration
        sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.9)
        sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=3.1)
        sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0)

        # Add outputs to verify the execution
        _output1 = g.node("PreviewImage", images=sleep_node1.out(0))
        _output2 = g.node("PreviewImage", images=sleep_node2.out(0))
        _output3 = g.node("PreviewImage", images=sleep_node3.out(0))

        start_time = time.time()
        result = client.run(g)
        elapsed_time = time.time() - start_time

        # The test should take around 3.0 seconds (the longest sleep duration)
        # plus some overhead, but definitely less than the sum of all sleeps (9.0s)
        assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s"

        # Verify that all nodes executed
        assert result.did_run(sleep_node1), "Sleep node 1 should have run"
        assert result.did_run(sleep_node2), "Sleep node 2 should have run"
        assert result.did_run(sleep_node3), "Sleep node 3 should have run"

    def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
        # Warmup execution to ensure server is fully initialized
        run_warmup(client)

        g = builder
        # Create input images with different values
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        image3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Create a TestParallelSleep node that expands into multiple TestSleep nodes
        parallel_sleep = g.node("TestParallelSleep",
                                image1=image1.out(0),
                                image2=image2.out(0),
                                image3=image3.out(0),
                                sleep1=4.8,
                                sleep2=4.9,
                                sleep3=5.0)
        output = g.node("SaveImage", images=parallel_sleep.out(0))

        start_time = time.time()
        result = client.run(g)
        elapsed_time = time.time() - start_time

        # Similar to the previous test, expect parallel execution of the sleep nodes
        # which should complete in less than the sum of all sleeps
        assert elapsed_time < 10.0, f"Expansion execution took {elapsed_time}s, expected less than 10.0s"

        # Verify the parallel sleep node executed
        assert result.did_run(parallel_sleep), "ParallelSleep node should have run"

        # Verify we get an image as output (blend of the three input images)
        result_images = result.get_images(output)
        assert len(result_images) == 1, "Should have 1 image"
        # Average pixel value should be around 170 (255 * 2 // 3)
        avg_value = numpy.array(result_images[0]).mean()
        assert avg_value == 170, f"Image average value {avg_value} should be 170"

    # This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
    # as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
    # only that one entry in the list is blocked.
    def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        image3 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image_list = g.node("TestMakeListNode", value1=image1.out(0), value2=image2.out(0), value3=image3.out(0))
        int1 = g.node("StubInt", value=1)
        int2 = g.node("StubInt", value=2)
        int3 = g.node("StubInt", value=3)
        int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0))
        compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==")
        blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)

        list_output = g.node("TestMakeListNode", value1=blocker.out(0))
        output = g.node("PreviewImage", images=list_output.out(0))

        result = client.run(g)
        assert result.did_run(output), "The execution should have run"
        images = result.get_images(output)
        assert len(images) == 2, "Should have 2 images"
        assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
        assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"

    # Output nodes included in the partial execution list are executed
    def test_partial_execution_included_outputs(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Create two separate output nodes
        output1 = g.node("SaveImage", images=input1.out(0))
        output2 = g.node("SaveImage", images=input2.out(0))

        # Run with partial execution targeting only output1
        result = client.run(g, partial_execution_targets=[output1.id])

        assert result.was_executed(input1), "Input1 should have been executed (run or cached)"
        assert result.was_executed(output1), "Output1 should have been executed (run or cached)"
        assert not result.did_run(input2), "Input2 should not have run"
        assert not result.did_run(output2), "Output2 should not have run"

        # Verify only output1 produced results
        assert len(result.get_images(output1)) == 1, "Output1 should have produced an image"
        assert len(result.get_images(output2)) == 0, "Output2 should not have produced an image"

    # Output nodes NOT included in the partial execution list are NOT executed
    def test_partial_execution_excluded_outputs(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)

        # Create three output nodes
        output1 = g.node("SaveImage", images=input1.out(0))
        output2 = g.node("SaveImage", images=input2.out(0))
        output3 = g.node("SaveImage", images=input3.out(0))

        # Run with partial execution targeting only output1 and output3
        result = client.run(g, partial_execution_targets=[output1.id, output3.id])

        assert result.was_executed(input1), "Input1 should have been executed"
        assert result.was_executed(input3), "Input3 should have been executed"
        assert result.was_executed(output1), "Output1 should have been executed"
        assert result.was_executed(output3), "Output3 should have been executed"
        assert not result.did_run(input2), "Input2 should not have run"
        assert not result.did_run(output2), "Output2 should not have run"

    # Output nodes NOT in list ARE executed if necessary for nodes that are in the list
    def test_partial_execution_dependencies(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # Create a processing chain with an OUTPUT_NODE that has socket outputs
        output_with_socket = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=2.0)

        # Create another node that depends on the output_with_socket
        dependent_node = g.node("TestLazyMixImages",
                                image1=output_with_socket.out(0),
                                image2=input1.out(0),
                                mask=g.node("StubMask", value=0.5, height=512, width=512, batch_size=1).out(0))

        # Create the final output
        final_output = g.node("SaveImage", images=dependent_node.out(0))

        # Run with partial execution targeting only the final output
        result = client.run(g, partial_execution_targets=[final_output.id])

        # All nodes should have been executed because they're dependencies
        assert result.was_executed(input1), "Input1 should have been executed"
        assert result.was_executed(output_with_socket), "Output with socket should have been executed (dependency)"
        assert result.was_executed(dependent_node), "Dependent node should have been executed"
        assert result.was_executed(final_output), "Final output should have been executed"

    # Lazy execution works with partial execution
    def test_partial_execution_with_lazy_nodes(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)

        # Create masks that will trigger different lazy execution paths
        mask1 = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1)  # Will only need image1
        mask2 = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)  # Will need both images

        # Create two lazy mix nodes
        lazy_mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask1.out(0))
        lazy_mix2 = g.node("TestLazyMixImages", image1=input2.out(0), image2=input3.out(0), mask=mask2.out(0))

        output1 = g.node("SaveImage", images=lazy_mix1.out(0))
        output2 = g.node("SaveImage", images=lazy_mix2.out(0))

        # Run with partial execution targeting only output1
        result = client.run(g, partial_execution_targets=[output1.id])

        # For output1 path - only input1 should run due to lazy evaluation (mask=0.0)
        assert result.was_executed(input1), "Input1 should have been executed"
        assert not result.did_run(input2), "Input2 should not have run (lazy evaluation)"
        assert result.was_executed(mask1), "Mask1 should have been executed"
        assert result.was_executed(lazy_mix1), "Lazy mix1 should have been executed"
        assert result.was_executed(output1), "Output1 should have been executed"

        # Nothing from output2 path should run
        assert not result.did_run(input3), "Input3 should not have run"
        assert not result.did_run(mask2), "Mask2 should not have run"
        assert not result.did_run(lazy_mix2), "Lazy mix2 should not have run"
        assert not result.did_run(output2), "Output2 should not have run"

    # Multiple OUTPUT_NODEs with dependencies
    def test_partial_execution_multiple_output_nodes(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Create a chain of OUTPUT_NODEs
        output_node1 = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=1.5)
        output_node2 = g.node("TestOutputNodeWithSocketOutput", image=output_node1.out(0), value=2.0)

        # Create regular output nodes
        save1 = g.node("SaveImage", images=output_node1.out(0))
        save2 = g.node("SaveImage", images=output_node2.out(0))
        save3 = g.node("SaveImage", images=input2.out(0))

        # Run targeting only save2
        result = client.run(g, partial_execution_targets=[save2.id])

        # Should run: input1, output_node1, output_node2, save2
        assert result.was_executed(input1), "Input1 should have been executed"
        assert result.was_executed(output_node1), "Output node 1 should have been executed (dependency)"
        assert result.was_executed(output_node2), "Output node 2 should have been executed (dependency)"
        assert result.was_executed(save2), "Save2 should have been executed"

        # Should NOT run: input2, save1, save3
        assert not result.did_run(input2), "Input2 should not have run"
        assert not result.did_run(save1), "Save1 should not have run"
        assert not result.did_run(save3), "Save3 should not have run"

    # Empty partial execution list (should execute nothing)
    def test_partial_execution_empty_list(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        _output1 = g.node("SaveImage", images=input1.out(0))

        # Run with empty partial execution list
        try:
            _result = client.run(g, partial_execution_targets=[])
            # Should get an error because no outputs are selected
            assert False, "Should have raised an error for empty partial execution list"
        except urllib.error.HTTPError:
            pass  # Expected behavior
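For orientation, a hedged sketch of driving ComfyClient by hand outside pytest. It assumes a server is already listening on 127.0.0.1:8188 and was started with the testing-nodes config above, so StubImage resolves; the test name and image sizes are arbitrary:

from comfy_execution.graph_utils import GraphBuilder
from tests.execution.test_execution import ComfyClient

client = ComfyClient()
client.connect(listen='127.0.0.1', port=8188)
client.set_test_name("manual_smoke_test")  # becomes the SaveImage filename prefix

g = GraphBuilder(prefix="manual")
image = g.node("StubImage", content="BLACK", height=64, width=64, batch_size=1)
save = g.node("SaveImage", images=image.out(0))

result = client.run(g)                    # blocks until the prompt finishes
assert result.did_run(save)               # ran on a cold cache
assert len(result.get_images(save)) == 1  # one black 64x64 image saved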
233 tests/execution/test_progress_isolation.py Normal file
@@ -0,0 +1,233 @@
"""Test that progress updates are properly isolated between WebSocket clients."""

import json
import pytest
import time
import threading
import uuid
import websocket
from typing import List, Dict, Any
from comfy_execution.graph_utils import GraphBuilder
from tests.execution.test_execution import ComfyClient


class ProgressTracker:
    """Tracks progress messages received by a WebSocket client."""

    def __init__(self, client_id: str):
        self.client_id = client_id
        self.progress_messages: List[Dict[str, Any]] = []
        self.lock = threading.Lock()

    def add_message(self, message: Dict[str, Any]):
        """Thread-safe addition of progress messages."""
        with self.lock:
            self.progress_messages.append(message)

    def get_messages_for_prompt(self, prompt_id: str) -> List[Dict[str, Any]]:
        """Get all progress messages for a specific prompt_id."""
        with self.lock:
            return [
                msg for msg in self.progress_messages
                if msg.get('data', {}).get('prompt_id') == prompt_id
            ]

    def has_cross_contamination(self, own_prompt_id: str) -> bool:
        """Check if this client received progress for other prompts."""
        with self.lock:
            for msg in self.progress_messages:
                msg_prompt_id = msg.get('data', {}).get('prompt_id')
                if msg_prompt_id and msg_prompt_id != own_prompt_id:
                    return True
            return False
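
    # "Cross-contamination" means this client saw a progress_state message for
    # a prompt it did not submit; that is the failure mode these tests guard
    # against.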


class IsolatedClient(ComfyClient):
    """Extended ComfyClient that tracks all WebSocket messages."""

    def __init__(self):
        super().__init__()
        self.progress_tracker = None
        self.all_messages: List[Dict[str, Any]] = []

    def connect(self, listen='127.0.0.1', port=8188, client_id=None):
        """Connect with a specific client_id and set up message tracking."""
        if client_id is None:
            client_id = str(uuid.uuid4())
        super().connect(listen, port, client_id)
        self.progress_tracker = ProgressTracker(client_id)

    def listen_for_messages(self, duration: float = 5.0):
        """Listen for WebSocket messages for a specified duration."""
        end_time = time.time() + duration
        self.ws.settimeout(0.5)  # Non-blocking with timeout

        while time.time() < end_time:
            try:
                out = self.ws.recv()
                if isinstance(out, str):
                    message = json.loads(out)
                    self.all_messages.append(message)

                    # Track progress_state messages
                    if message.get('type') == 'progress_state':
                        self.progress_tracker.add_message(message)
            except websocket.WebSocketTimeoutException:
                continue
            except Exception:
                # Ignore other errors silently in the test context
                break


@pytest.mark.execution
class TestProgressIsolation:
    """Test suite for verifying progress update isolation between clients."""

    @pytest.fixture(scope="class", autouse=True)
    def _server(self, args_pytest):
        """Start the ComfyUI server for testing."""
        import subprocess
        pargs = [
            'python', 'main.py',
            '--output-directory', args_pytest["output_dir"],
            '--listen', args_pytest["listen"],
            '--port', str(args_pytest["port"]),
            '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
            '--cpu',
        ]
        p = subprocess.Popen(pargs)
        yield
        p.kill()

    def start_client_with_retry(self, listen: str, port: int, client_id: str = None):
        """Start a client, retrying until the server accepts the connection."""
        client = IsolatedClient()
        # Connect to server (with retries)
        n_tries = 5
        for i in range(n_tries):
            time.sleep(4)
            try:
                client.connect(listen, port, client_id)
                return client
            except ConnectionRefusedError as e:
                print(e)  # noqa: T201
                print(f"({i+1}/{n_tries}) Retrying...")  # noqa: T201
        raise ConnectionRefusedError(f"Failed to connect after {n_tries} attempts")

    def test_progress_isolation_between_clients(self, args_pytest):
        """Test that progress updates are isolated between different clients."""
        listen = args_pytest["listen"]
        port = args_pytest["port"]

        # Create two separate clients with unique IDs
        client_a_id = "client_a_" + str(uuid.uuid4())
        client_b_id = "client_b_" + str(uuid.uuid4())

        client_a = None
        client_b = None
        try:
            # Connect both clients with retries
            client_a = self.start_client_with_retry(listen, port, client_a_id)
            client_b = self.start_client_with_retry(listen, port, client_b_id)

            # Create simple workflows for both clients
            graph_a = GraphBuilder(prefix="client_a")
            image_a = graph_a.node("StubImage", content="BLACK", height=256, width=256, batch_size=1)
            graph_a.node("PreviewImage", images=image_a.out(0))

            graph_b = GraphBuilder(prefix="client_b")
            image_b = graph_b.node("StubImage", content="WHITE", height=256, width=256, batch_size=1)
            graph_b.node("PreviewImage", images=image_b.out(0))

            # Submit workflows from both clients
            prompt_a = graph_a.finalize()
            prompt_b = graph_b.finalize()

            response_a = client_a.queue_prompt(prompt_a)
            prompt_id_a = response_a['prompt_id']

            response_b = client_b.queue_prompt(prompt_b)
            prompt_id_b = response_b['prompt_id']

            # Start threads to listen for messages on both clients
            def listen_client_a():
                client_a.listen_for_messages(duration=10.0)

            def listen_client_b():
                client_b.listen_for_messages(duration=10.0)

            thread_a = threading.Thread(target=listen_client_a)
            thread_b = threading.Thread(target=listen_client_b)

            thread_a.start()
            thread_b.start()

            # Wait for threads to complete
            thread_a.join()
            thread_b.join()

            # Verify isolation
            # Client A should only receive progress for prompt_id_a
            assert not client_a.progress_tracker.has_cross_contamination(prompt_id_a), \
                f"Client A received progress updates for other clients' workflows. " \
                f"Expected only {prompt_id_a}, but got messages for multiple prompts."

            # Client B should only receive progress for prompt_id_b
            assert not client_b.progress_tracker.has_cross_contamination(prompt_id_b), \
                f"Client B received progress updates for other clients' workflows. " \
                f"Expected only {prompt_id_b}, but got messages for multiple prompts."

            # Verify each client received their own progress updates
            client_a_messages = client_a.progress_tracker.get_messages_for_prompt(prompt_id_a)
            client_b_messages = client_b.progress_tracker.get_messages_for_prompt(prompt_id_b)

            assert len(client_a_messages) > 0, \
                "Client A did not receive any progress updates for its own workflow"
            assert len(client_b_messages) > 0, \
                "Client B did not receive any progress updates for its own workflow"

            # Ensure no cross-contamination
            client_a_other = client_a.progress_tracker.get_messages_for_prompt(prompt_id_b)
            client_b_other = client_b.progress_tracker.get_messages_for_prompt(prompt_id_a)

            assert len(client_a_other) == 0, \
                f"Client A incorrectly received {len(client_a_other)} progress updates for Client B's workflow"
            assert len(client_b_other) == 0, \
                f"Client B incorrectly received {len(client_b_other)} progress updates for Client A's workflow"

        finally:
            # Clean up connections (guard against a failed connect leaving them unset)
            if client_a is not None and hasattr(client_a, 'ws'):
                client_a.ws.close()
            if client_b is not None and hasattr(client_b, 'ws'):
                client_b.ws.close()

    def test_progress_with_missing_client_id(self, args_pytest):
        """Test that progress updates handle a missing client_id gracefully."""
        listen = args_pytest["listen"]
        port = args_pytest["port"]

        client = None
        try:
            # Connect client with retries
            client = self.start_client_with_retry(listen, port)

            # Create a simple workflow
            graph = GraphBuilder(prefix="test_missing_id")
            image = graph.node("StubImage", content="BLACK", height=128, width=128, batch_size=1)
            graph.node("PreviewImage", images=image.out(0))

            # Submit workflow
            prompt = graph.finalize()
            response = client.queue_prompt(prompt)
            prompt_id = response['prompt_id']

            # Listen for messages
            client.listen_for_messages(duration=5.0)

            # Should still receive progress updates for own workflow
            messages = client.progress_tracker.get_messages_for_prompt(prompt_id)
            assert len(messages) > 0, \
                "Client did not receive progress updates even though it initiated the workflow"

        finally:
            if client is not None and hasattr(client, 'ws'):
                client.ws.close()
28
tests/execution/testing_nodes/testing-pack/__init__.py
Normal file
@@ -0,0 +1,28 @@
from .specific_tests import TEST_NODE_CLASS_MAPPINGS, TEST_NODE_DISPLAY_NAME_MAPPINGS
from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS
from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS
from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS
from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS
from .async_test_nodes import ASYNC_TEST_NODE_CLASS_MAPPINGS, ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS
from .api_test_nodes import API_TEST_NODE_CLASS_MAPPINGS, API_TEST_NODE_DISPLAY_NAME_MAPPINGS

# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)

NODE_CLASS_MAPPINGS = {}
NODE_CLASS_MAPPINGS.update(TEST_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(ASYNC_TEST_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(API_TEST_NODE_CLASS_MAPPINGS)

NODE_DISPLAY_NAME_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(API_TEST_NODE_DISPLAY_NAME_MAPPINGS)
78
tests/execution/testing_nodes/testing-pack/api_test_nodes.py
Normal file
@@ -0,0 +1,78 @@
import asyncio
import time
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
from comfy_api.v0_0_2 import ComfyAPI, ComfyAPISync

api = ComfyAPI()
api_sync = ComfyAPISync()
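
# ComfyAPI exposes awaitable calls for use from async nodes, while
# ComfyAPISync wraps the same surface for synchronous node functions; both
# are used below to report fractional progress while a node is still running.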


class TestAsyncProgressUpdate(ComfyNodeABC):
    """Test node that reports progress from an async execution function."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {
                "value": (IO.ANY, {}),
                "sleep_seconds": (IO.FLOAT, {"default": 1.0}),
            },
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "execute"
    CATEGORY = "_for_testing/async"

    async def execute(self, value, sleep_seconds):
        start = time.time()
        expiration = start + sleep_seconds
        now = start
        while now < expiration:
            now = time.time()
            await api.execution.set_progress(
                value=(now - start) / sleep_seconds,
                max_value=1.0,
            )
            await asyncio.sleep(0.01)
        return (value,)


class TestSyncProgressUpdate(ComfyNodeABC):
    """Test node that reports progress from a synchronous execution function."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {
                "value": (IO.ANY, {}),
                "sleep_seconds": (IO.FLOAT, {"default": 1.0}),
            },
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "execute"
    CATEGORY = "_for_testing/async"

    def execute(self, value, sleep_seconds):
        start = time.time()
        expiration = start + sleep_seconds
        now = start
        while now < expiration:
            now = time.time()
            api_sync.execution.set_progress(
                value=(now - start) / sleep_seconds,
                max_value=1.0,
            )
            time.sleep(0.01)
        return (value,)


API_TEST_NODE_CLASS_MAPPINGS = {
    "TestAsyncProgressUpdate": TestAsyncProgressUpdate,
    "TestSyncProgressUpdate": TestSyncProgressUpdate,
}

API_TEST_NODE_DISPLAY_NAME_MAPPINGS = {
    "TestAsyncProgressUpdate": "Async Progress Update Test Node",
    "TestSyncProgressUpdate": "Sync Progress Update Test Node",
}
343
tests/execution/testing_nodes/testing-pack/async_test_nodes.py
Normal file
@@ -0,0 +1,343 @@
import torch
import asyncio
from typing import Dict
from comfy.utils import ProgressBar
from comfy_execution.graph_utils import GraphBuilder
from comfy.comfy_types.node_typing import ComfyNodeABC
from comfy.comfy_types import IO


class TestAsyncValidation(ComfyNodeABC):
    """Test node with async VALIDATE_INPUTS."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("FLOAT", {"default": 5.0}),
                "threshold": ("FLOAT", {"default": 10.0}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "process"
    CATEGORY = "_for_testing/async"

    @classmethod
    async def VALIDATE_INPUTS(cls, value, threshold):
        # Simulate async validation (e.g., checking a remote service)
        await asyncio.sleep(0.05)

        if value > threshold:
            return f"Value {value} exceeds threshold {threshold}"
        return True

    def process(self, value, threshold):
        # Create an image whose intensity is derived from the value
        intensity = value / 10.0
        image = torch.ones([1, 512, 512, 3]) * intensity
        return (image,)


class TestAsyncError(ComfyNodeABC):
    """Test node that errors during async execution."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": (IO.ANY, {}),
                "error_after": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 10.0}),
            },
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "error_execution"
    CATEGORY = "_for_testing/async"

    async def error_execution(self, value, error_after):
        await asyncio.sleep(error_after)
        raise RuntimeError("Intentional async execution error for testing")


class TestAsyncValidationError(ComfyNodeABC):
    """Test node with async validation that fails when value exceeds max_value."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("FLOAT", {"default": 5.0}),
                "max_value": ("FLOAT", {"default": 10.0}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "process"
    CATEGORY = "_for_testing/async"

    @classmethod
    async def VALIDATE_INPUTS(cls, value, max_value):
        await asyncio.sleep(0.05)
        # Fail validation for values > max_value
        if value > max_value:
            return f"Async validation failed: {value} > {max_value}"
        return True

    def process(self, value, max_value):
        # This won't be reached if validation fails
        image = torch.ones([1, 512, 512, 3]) * (value / max_value)
        return (image,)


class TestAsyncTimeout(ComfyNodeABC):
    """Test node that simulates timeout scenarios."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": (IO.ANY, {}),
                "timeout": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0}),
                "operation_time": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 10.0}),
            },
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "timeout_execution"
    CATEGORY = "_for_testing/async"

    async def timeout_execution(self, value, timeout, operation_time):
        try:
            # This will time out if operation_time > timeout
            await asyncio.wait_for(asyncio.sleep(operation_time), timeout=timeout)
            return (value,)
        except asyncio.TimeoutError:
            raise RuntimeError(f"Operation timed out after {timeout} seconds")


class TestSyncError(ComfyNodeABC):
    """Test node that errors synchronously (for mixed sync/async testing)."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": (IO.ANY, {}),
            },
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "sync_error"
    CATEGORY = "_for_testing/async"

    def sync_error(self, value):
        raise RuntimeError("Intentional sync execution error for testing")


class TestAsyncLazyCheck(ComfyNodeABC):
    """Test node with async check_lazy_status."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input1": (IO.ANY, {"lazy": True}),
                "input2": (IO.ANY, {"lazy": True}),
                "condition": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "process"
    CATEGORY = "_for_testing/async"

    async def check_lazy_status(self, condition, input1, input2):
        # Simulate async checking (e.g., querying a remote service)
        await asyncio.sleep(0.05)

        needed = []
        if condition and input1 is None:
            needed.append("input1")
        if not condition and input2 is None:
            needed.append("input2")
        return needed

    def process(self, input1, input2, condition):
        # Return a simple image
        return (torch.ones([1, 512, 512, 3]),)


class TestDynamicAsyncGeneration(ComfyNodeABC):
    """Test node that dynamically generates async nodes."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image1": ("IMAGE",),
                "image2": ("IMAGE",),
                "num_async_nodes": ("INT", {"default": 3, "min": 1, "max": 10}),
                "sleep_duration": ("FLOAT", {"default": 0.2, "min": 0.1, "max": 1.0}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "generate_async_workflow"
    CATEGORY = "_for_testing/async"

    def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration):
        g = GraphBuilder()

        # Create multiple async sleep nodes
        sleep_nodes = []
        for i in range(num_async_nodes):
            image = image1 if i % 2 == 0 else image2
            sleep_node = g.node("TestSleep", value=image, seconds=sleep_duration)
            sleep_nodes.append(sleep_node)

        # Average all results
        if len(sleep_nodes) == 1:
            final_node = sleep_nodes[0]
        else:
            avg_inputs = {"input1": sleep_nodes[0].out(0)}
            for i, node in enumerate(sleep_nodes[1:], 2):
                avg_inputs[f"input{i}"] = node.out(0)
            final_node = g.node("TestVariadicAverage", **avg_inputs)

        return {
            "result": (final_node.out(0),),
            "expand": g.finalize(),
        }


class TestAsyncResourceUser(ComfyNodeABC):
    """Test node that uses resources during async execution."""

    # Class-level resource tracking for testing
    _active_resources: Dict[str, bool] = {}

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": (IO.ANY, {}),
                "resource_id": ("STRING", {"default": "resource_0"}),
                "duration": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0}),
            },
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "use_resource"
    CATEGORY = "_for_testing/async"

    async def use_resource(self, value, resource_id, duration):
        # Check if the resource is already in use
        if self._active_resources.get(resource_id, False):
            raise RuntimeError(f"Resource {resource_id} is already in use!")

        # Mark the resource as in use
        self._active_resources[resource_id] = True

        try:
            # Simulate resource usage
            await asyncio.sleep(duration)
            return (value,)
        finally:
            # Always clean up the resource
            self._active_resources[resource_id] = False


class TestAsyncBatchProcessing(ComfyNodeABC):
    """Test async processing of batched inputs."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "process_time_per_item": ("FLOAT", {"default": 0.1, "min": 0.01, "max": 1.0}),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "process_batch"
    CATEGORY = "_for_testing/async"

    async def process_batch(self, images, process_time_per_item, unique_id):
        batch_size = images.shape[0]
        pbar = ProgressBar(batch_size, node_id=unique_id)

        # Process each image in the batch
        processed = []
        for i in range(batch_size):
            # Simulate async processing
            await asyncio.sleep(process_time_per_item)

            # Simple processing: invert the image
            processed_image = 1.0 - images[i:i+1]
            processed.append(processed_image)

            pbar.update(1)

        # Stack the processed images
        result = torch.cat(processed, dim=0)
        return (result,)


class TestAsyncConcurrentLimit(ComfyNodeABC):
    """Test concurrent execution limits for async nodes."""

    _semaphore = asyncio.Semaphore(2)  # Only allow 2 concurrent executions
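
    # NOTE: the semaphore is shared at class level, so at most two instances of
    # this node can be inside limited_execution at the same time; additional
    # callers queue on the `async with` below.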

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": (IO.ANY, {}),
                "duration": ("FLOAT", {"default": 0.5, "min": 0.1, "max": 2.0}),
                "node_id": ("INT", {"default": 0}),
            },
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "limited_execution"
    CATEGORY = "_for_testing/async"

    async def limited_execution(self, value, duration, node_id):
        async with self._semaphore:
            # Node {node_id} acquired the semaphore
            await asyncio.sleep(duration)
            # Node {node_id} is releasing the semaphore
            return (value,)


# Add node mappings
ASYNC_TEST_NODE_CLASS_MAPPINGS = {
    "TestAsyncValidation": TestAsyncValidation,
    "TestAsyncError": TestAsyncError,
    "TestAsyncValidationError": TestAsyncValidationError,
    "TestAsyncTimeout": TestAsyncTimeout,
    "TestSyncError": TestSyncError,
    "TestAsyncLazyCheck": TestAsyncLazyCheck,
    "TestDynamicAsyncGeneration": TestDynamicAsyncGeneration,
    "TestAsyncResourceUser": TestAsyncResourceUser,
    "TestAsyncBatchProcessing": TestAsyncBatchProcessing,
    "TestAsyncConcurrentLimit": TestAsyncConcurrentLimit,
}

ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS = {
    "TestAsyncValidation": "Test Async Validation",
    "TestAsyncError": "Test Async Error",
    "TestAsyncValidationError": "Test Async Validation Error",
    "TestAsyncTimeout": "Test Async Timeout",
    "TestSyncError": "Test Sync Error",
    "TestAsyncLazyCheck": "Test Async Lazy Check",
    "TestDynamicAsyncGeneration": "Test Dynamic Async Generation",
    "TestAsyncResourceUser": "Test Async Resource User",
    "TestAsyncBatchProcessing": "Test Async Batch Processing",
    "TestAsyncConcurrentLimit": "Test Async Concurrent Limit",
}
194
tests/execution/testing_nodes/testing-pack/conditions.py
Normal file
@@ -0,0 +1,194 @@
import re
import torch

class TestIntConditions:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
                "b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
                "operation": (["==", "!=", "<", ">", "<=", ">="],),
            },
        }

    RETURN_TYPES = ("BOOLEAN",)
    FUNCTION = "int_condition"

    CATEGORY = "Testing/Logic"

    def int_condition(self, a, b, operation):
        if operation == "==":
            return (a == b,)
        elif operation == "!=":
            return (a != b,)
        elif operation == "<":
            return (a < b,)
        elif operation == ">":
            return (a > b,)
        elif operation == "<=":
            return (a <= b,)
        elif operation == ">=":
            return (a >= b,)


class TestFloatConditions:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "a": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 999999999999.0, "step": 1}),
                "b": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 999999999999.0, "step": 1}),
                "operation": (["==", "!=", "<", ">", "<=", ">="],),
            },
        }

    RETURN_TYPES = ("BOOLEAN",)
    FUNCTION = "float_condition"

    CATEGORY = "Testing/Logic"

    def float_condition(self, a, b, operation):
        if operation == "==":
            return (a == b,)
        elif operation == "!=":
            return (a != b,)
        elif operation == "<":
            return (a < b,)
        elif operation == ">":
            return (a > b,)
        elif operation == "<=":
            return (a <= b,)
        elif operation == ">=":
            return (a >= b,)

class TestStringConditions:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "a": ("STRING", {"multiline": False}),
                "b": ("STRING", {"multiline": False}),
                "operation": (["a == b", "a != b", "a IN b", "a MATCH REGEX(b)", "a BEGINSWITH b", "a ENDSWITH b"],),
                "case_sensitive": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("BOOLEAN",)
    FUNCTION = "string_condition"

    CATEGORY = "Testing/Logic"

    def string_condition(self, a, b, operation, case_sensitive):
        if not case_sensitive:
            a = a.lower()
            b = b.lower()

        if operation == "a == b":
            return (a == b,)
        elif operation == "a != b":
            return (a != b,)
        elif operation == "a IN b":
            return (a in b,)
        elif operation == "a MATCH REGEX(b)":
            try:
                return (re.match(b, a) is not None,)
            except re.error:
                # An invalid pattern simply doesn't match
                return (False,)
        elif operation == "a BEGINSWITH b":
            return (a.startswith(b),)
        elif operation == "a ENDSWITH b":
            return (a.endswith(b),)

class TestToBoolNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("*",),
            },
            "optional": {
                "invert": ("BOOLEAN", {"default": False}),
            },
        }

    RETURN_TYPES = ("BOOLEAN",)
    FUNCTION = "to_bool"

    CATEGORY = "Testing/Logic"

    def to_bool(self, value, invert = False):
        if isinstance(value, torch.Tensor):
            if value.max().item() == 0 and value.min().item() == 0:
                result = False
            else:
                result = True
        else:
            try:
                result = bool(value)
            except Exception:
                # Can't convert it? Treat any unconvertible value as truthy.
                result = True

        if invert:
            result = not result

        return (result,)

class TestBoolOperationNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "a": ("BOOLEAN",),
                "b": ("BOOLEAN",),
                "op": (["a AND b", "a OR b", "a XOR b", "NOT a"],),
            },
        }

    RETURN_TYPES = ("BOOLEAN",)
    FUNCTION = "bool_operation"

    CATEGORY = "Testing/Logic"

    def bool_operation(self, a, b, op):
        if op == "a AND b":
            return (a and b,)
        elif op == "a OR b":
            return (a or b,)
        elif op == "a XOR b":
            return (a ^ b,)
        elif op == "NOT a":
            return (not a,)


CONDITION_NODE_CLASS_MAPPINGS = {
    "TestIntConditions": TestIntConditions,
    "TestFloatConditions": TestFloatConditions,
    "TestStringConditions": TestStringConditions,
    "TestToBoolNode": TestToBoolNode,
    "TestBoolOperationNode": TestBoolOperationNode,
}

CONDITION_NODE_DISPLAY_NAME_MAPPINGS = {
    "TestIntConditions": "Int Condition",
    "TestFloatConditions": "Float Condition",
    "TestStringConditions": "String Condition",
    "TestToBoolNode": "To Bool",
    "TestBoolOperationNode": "Bool Operation",
}
173
tests/execution/testing_nodes/testing-pack/flow_control.py
Normal file
@@ -0,0 +1,173 @@
from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.graph import ExecutionBlocker
from .tools import VariantSupport

NUM_FLOW_SOCKETS = 5
@VariantSupport()
class TestWhileLoopOpen:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        inputs = {
            "required": {
                "condition": ("BOOLEAN", {"default": True}),
            },
            "optional": {
            },
        }
        for i in range(NUM_FLOW_SOCKETS):
            inputs["optional"][f"initial_value{i}"] = ("*",)
        return inputs

    RETURN_TYPES = tuple(["FLOW_CONTROL"] + ["*"] * NUM_FLOW_SOCKETS)
    RETURN_NAMES = tuple(["FLOW_CONTROL"] + [f"value{i}" for i in range(NUM_FLOW_SOCKETS)])
    FUNCTION = "while_loop_open"

    CATEGORY = "Testing/Flow"

    def while_loop_open(self, condition, **kwargs):
        values = []
        for i in range(NUM_FLOW_SOCKETS):
            values.append(kwargs.get(f"initial_value{i}", None))
        return tuple(["stub"] + values)

@VariantSupport()
class TestWhileLoopClose:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        inputs = {
            "required": {
                "flow_control": ("FLOW_CONTROL", {"rawLink": True}),
                "condition": ("BOOLEAN", {"forceInput": True}),
            },
            "optional": {
            },
            "hidden": {
                "dynprompt": "DYNPROMPT",
                "unique_id": "UNIQUE_ID",
            }
        }
        for i in range(NUM_FLOW_SOCKETS):
            inputs["optional"][f"initial_value{i}"] = ("*",)
        return inputs

    RETURN_TYPES = tuple(["*"] * NUM_FLOW_SOCKETS)
    RETURN_NAMES = tuple([f"value{i}" for i in range(NUM_FLOW_SOCKETS)])
    FUNCTION = "while_loop_close"

    CATEGORY = "Testing/Flow"

    def explore_dependencies(self, node_id, dynprompt, upstream):
        node_info = dynprompt.get_node(node_id)
        if "inputs" not in node_info:
            return
        for k, v in node_info["inputs"].items():
            if is_link(v):
                parent_id = v[0]
                if parent_id not in upstream:
                    upstream[parent_id] = []
                    self.explore_dependencies(parent_id, dynprompt, upstream)
                upstream[parent_id].append(node_id)

    def collect_contained(self, node_id, upstream, contained):
        if node_id not in upstream:
            return
        for child_id in upstream[node_id]:
            if child_id not in contained:
                contained[child_id] = True
                self.collect_contained(child_id, upstream, contained)
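
    # Together, explore_dependencies and collect_contained identify the subgraph
    # between the while-open and while-close nodes: upstream maps each node to
    # its dependents, and contained collects everything reachable from the open
    # node through those edges. That subgraph is what gets cloned below for the
    # next loop iteration.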

    def while_loop_close(self, flow_control, condition, dynprompt=None, unique_id=None, **kwargs):
        assert dynprompt is not None
        if not condition:
            # We're done with the loop
            values = []
            for i in range(NUM_FLOW_SOCKETS):
                values.append(kwargs.get(f"initial_value{i}", None))
            return tuple(values)

        # We want to loop
        upstream = {}
        # Get the list of all nodes between the open and close nodes
        self.explore_dependencies(unique_id, dynprompt, upstream)

        contained = {}
        open_node = flow_control[0]
        self.collect_contained(open_node, upstream, contained)
        contained[unique_id] = True
        contained[open_node] = True

        # We'll use the default prefix, but to avoid having node names grow exponentially in size,
        # we'll use "Recurse" for the name of the recursively-generated copy of this node.
        graph = GraphBuilder()
        for node_id in contained:
            original_node = dynprompt.get_node(node_id)
            node = graph.node(original_node["class_type"], "Recurse" if node_id == unique_id else node_id)
            node.set_override_display_id(node_id)
        for node_id in contained:
            original_node = dynprompt.get_node(node_id)
            node = graph.lookup_node("Recurse" if node_id == unique_id else node_id)
            assert node is not None
            for k, v in original_node["inputs"].items():
                if is_link(v) and v[0] in contained:
                    parent = graph.lookup_node(v[0])
                    assert parent is not None
                    node.set_input(k, parent.out(v[1]))
                else:
                    node.set_input(k, v)
        new_open = graph.lookup_node(open_node)
        assert new_open is not None
        for i in range(NUM_FLOW_SOCKETS):
            key = f"initial_value{i}"
            new_open.set_input(key, kwargs.get(key, None))
        my_clone = graph.lookup_node("Recurse")
        assert my_clone is not None
        result = map(lambda x: my_clone.out(x), range(NUM_FLOW_SOCKETS))
        return {
            "result": tuple(result),
            "expand": graph.finalize(),
        }

@VariantSupport()
class TestExecutionBlockerNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        inputs = {
            "required": {
                "input": ("*",),
                "block": ("BOOLEAN",),
                "verbose": ("BOOLEAN", {"default": False}),
            },
        }
        return inputs

    RETURN_TYPES = ("*",)
    RETURN_NAMES = ("output",)
    FUNCTION = "execution_blocker"

    CATEGORY = "Testing/Flow"

    def execution_blocker(self, input, block, verbose):
        if block:
            return (ExecutionBlocker("Blocked Execution" if verbose else None),)
        return (input,)

FLOW_CONTROL_NODE_CLASS_MAPPINGS = {
    "TestWhileLoopOpen": TestWhileLoopOpen,
    "TestWhileLoopClose": TestWhileLoopClose,
    "TestExecutionBlocker": TestExecutionBlockerNode,
}
FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS = {
    "TestWhileLoopOpen": "While Loop Open",
    "TestWhileLoopClose": "While Loop Close",
    "TestExecutionBlocker": "Execution Blocker",
}
519
tests/execution/testing_nodes/testing-pack/specific_tests.py
Normal file
@@ -0,0 +1,519 @@
import torch
import time
import asyncio
from comfy.utils import ProgressBar
from .tools import VariantSupport
from comfy_execution.graph_utils import GraphBuilder
from comfy.comfy_types.node_typing import ComfyNodeABC
from comfy.comfy_types import IO

class TestLazyMixImages:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image1": ("IMAGE",{"lazy": True}),
                "image2": ("IMAGE",{"lazy": True}),
                "mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "mix"

    CATEGORY = "Testing/Nodes"

    def check_lazy_status(self, mask, image1, image2):
        mask_min = mask.min()
        mask_max = mask.max()
        needed = []
        if image1 is None and (mask_min != 1.0 or mask_max != 1.0):
            needed.append("image1")
        if image2 is None and (mask_min != 0.0 or mask_max != 0.0):
            needed.append("image2")
        return needed
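
    # check_lazy_status only requests the inputs the mix below actually needs:
    # a mask of all ones means image1 never contributes (so it is never
    # evaluated), and a mask of all zeros likewise skips image2.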

    # Not trying to handle different batch sizes here just to keep the demo simple
    def mix(self, mask, image1, image2):
        mask_min = mask.min()
        mask_max = mask.max()
        if mask_min == 0.0 and mask_max == 0.0:
            return (image1,)
        elif mask_min == 1.0 and mask_max == 1.0:
            return (image2,)

        if len(mask.shape) == 2:
            mask = mask.unsqueeze(0)
        if len(mask.shape) == 3:
            mask = mask.unsqueeze(3)
        if mask.shape[3] < image1.shape[3]:
            mask = mask.repeat(1, 1, 1, image1.shape[3])

        result = image1 * (1. - mask) + image2 * mask
        return (result,)

class TestVariadicAverage:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input1": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "variadic_average"

    CATEGORY = "Testing/Nodes"

    def variadic_average(self, input1, **kwargs):
        inputs = [input1]
        while 'input' + str(len(inputs) + 1) in kwargs:
            inputs.append(kwargs['input' + str(len(inputs) + 1)])
        return (torch.stack(inputs).mean(dim=0),)


class TestCustomIsChanged:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
            },
            "optional": {
                "should_change": ("BOOL", {"default": False}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_is_changed"

    CATEGORY = "Testing/Nodes"

    def custom_is_changed(self, image, should_change=False):
        return (image,)

    @classmethod
    def IS_CHANGED(cls, should_change=False, *args, **kwargs):
        if should_change:
            return float("NaN")
        else:
            return False

class TestIsChangedWithConstants:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_is_changed"

    CATEGORY = "Testing/Nodes"

    def custom_is_changed(self, image, value):
        return (image * value,)

    @classmethod
    def IS_CHANGED(cls, image, value):
        if image is None:
            return value
        else:
            return image.mean().item() * value

class TestCustomValidation1:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input1": ("IMAGE,FLOAT",),
                "input2": ("IMAGE,FLOAT",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation1"

    CATEGORY = "Testing/Nodes"

    def custom_validation1(self, input1, input2):
        if isinstance(input1, float) and isinstance(input2, float):
            result = torch.ones([1, 512, 512, 3]) * input1 * input2
        else:
            result = input1 * input2
        return (result,)

    @classmethod
    def VALIDATE_INPUTS(cls, input1=None, input2=None):
        if input1 is not None:
            if not isinstance(input1, (torch.Tensor, float)):
                return f"Invalid type of input1: {type(input1)}"
        if input2 is not None:
            if not isinstance(input2, (torch.Tensor, float)):
                return f"Invalid type of input2: {type(input2)}"

        return True

class TestCustomValidation2:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input1": ("IMAGE,FLOAT",),
                "input2": ("IMAGE,FLOAT",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation2"

    CATEGORY = "Testing/Nodes"

    def custom_validation2(self, input1, input2):
        if isinstance(input1, float) and isinstance(input2, float):
            result = torch.ones([1, 512, 512, 3]) * input1 * input2
        else:
            result = input1 * input2
        return (result,)

    @classmethod
    def VALIDATE_INPUTS(cls, input_types, input1=None, input2=None):
        if input1 is not None:
            if not isinstance(input1, (torch.Tensor, float)):
                return f"Invalid type of input1: {type(input1)}"
        if input2 is not None:
            if not isinstance(input2, (torch.Tensor, float)):
                return f"Invalid type of input2: {type(input2)}"

        if 'input1' in input_types:
            if input_types['input1'] not in ["IMAGE", "FLOAT"]:
                return f"Invalid type of input1: {input_types['input1']}"
        if 'input2' in input_types:
            if input_types['input2'] not in ["IMAGE", "FLOAT"]:
                return f"Invalid type of input2: {input_types['input2']}"

        return True

@VariantSupport()
class TestCustomValidation3:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input1": ("IMAGE,FLOAT",),
                "input2": ("IMAGE,FLOAT",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation3"

    CATEGORY = "Testing/Nodes"

    def custom_validation3(self, input1, input2):
        if isinstance(input1, float) and isinstance(input2, float):
            result = torch.ones([1, 512, 512, 3]) * input1 * input2
        else:
            result = input1 * input2
        return (result,)

class TestCustomValidation4:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input1": ("FLOAT",),
                "input2": ("FLOAT",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation4"

    CATEGORY = "Testing/Nodes"

    def custom_validation4(self, input1, input2):
        result = torch.ones([1, 512, 512, 3]) * input1 * input2
        return (result,)

    @classmethod
    def VALIDATE_INPUTS(cls, input1, input2):
        if input1 is not None:
            if not isinstance(input1, float):
                return f"Invalid type of input1: {type(input1)}"
        if input2 is not None:
            if not isinstance(input2, float):
                return f"Invalid type of input2: {type(input2)}"

        return True

class TestCustomValidation5:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input1": ("FLOAT", {"min": 0.0, "max": 1.0}),
                "input2": ("FLOAT", {"min": 0.0, "max": 1.0}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation5"

    CATEGORY = "Testing/Nodes"

    def custom_validation5(self, input1, input2):
        value = input1 * input2
        return (torch.ones([1, 512, 512, 3]) * value,)

    @classmethod
    def VALIDATE_INPUTS(cls, **kwargs):
        if kwargs['input2'] == 7.0:
            return "7s are not allowed. I've never liked 7s."
        return True

class TestDynamicDependencyCycle:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input1": ("IMAGE",),
                "input2": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "dynamic_dependency_cycle"

    CATEGORY = "Testing/Nodes"

    def dynamic_dependency_cycle(self, input1, input2):
        g = GraphBuilder()
        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
        mix1 = g.node("TestLazyMixImages", image1=input1, mask=mask.out(0))
        mix2 = g.node("TestLazyMixImages", image1=mix1.out(0), image2=input2, mask=mask.out(0))

        # Create the cycle
        mix1.set_input("image2", mix2.out(0))

        return {
            "result": (mix2.out(0),),
            "expand": g.finalize(),
        }

class TestMixedExpansionReturns:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input1": ("FLOAT",),
            },
        }

    RETURN_TYPES = ("IMAGE","IMAGE")
    FUNCTION = "mixed_expansion_returns"

    CATEGORY = "Testing/Nodes"

    def mixed_expansion_returns(self, input1):
        white_image = torch.ones([1, 512, 512, 3])
        if input1 <= 0.1:
            return (torch.ones([1, 512, 512, 3]) * 0.1, white_image)
        elif input1 <= 0.2:
            return {
                "result": (torch.ones([1, 512, 512, 3]) * 0.2, white_image),
            }
        else:
            g = GraphBuilder()
            mask = g.node("StubMask", value=0.3, height=512, width=512, batch_size=1)
            black = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
            white = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
            mix = g.node("TestLazyMixImages", image1=black.out(0), image2=white.out(0), mask=mask.out(0))
            return {
                "result": (mix.out(0), white_image),
                "expand": g.finalize(),
            }

class TestSamplingInExpansion:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
                "clip": ("CLIP",),
                "vae": ("VAE",),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "steps": ("INT", {"default": 20, "min": 1, "max": 100}),
                "cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 30.0}),
                "prompt": ("STRING", {"multiline": True, "default": "a beautiful landscape with mountains and trees"}),
                "negative_prompt": ("STRING", {"multiline": True, "default": "blurry, bad quality, worst quality"}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "sampling_in_expansion"

    CATEGORY = "Testing/Nodes"

    def sampling_in_expansion(self, model, clip, vae, seed, steps, cfg, prompt, negative_prompt):
        g = GraphBuilder()

        # Create a basic image generation workflow using the input model, clip and vae
        # 1. Set up the text prompts using the provided CLIP model
        positive_cond = g.node("CLIPTextEncode",
                               text=prompt,
                               clip=clip)
        negative_cond = g.node("CLIPTextEncode",
                               text=negative_prompt,
                               clip=clip)

        # 2. Create an empty latent of the specified size
        empty_latent = g.node("EmptyLatentImage", width=512, height=512, batch_size=1)

        # 3. Set up the sampler and generate the image latent
        sampler = g.node("KSampler",
                         model=model,
                         positive=positive_cond.out(0),
                         negative=negative_cond.out(0),
                         latent_image=empty_latent.out(0),
                         seed=seed,
                         steps=steps,
                         cfg=cfg,
                         sampler_name="euler_ancestral",
                         scheduler="normal")

        # 4. Decode the latent to an image using the VAE
        output = g.node("VAEDecode", samples=sampler.out(0), vae=vae)

        return {
            "result": (output.out(0),),
            "expand": g.finalize(),
        }

class TestSleep(ComfyNodeABC):
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": (IO.ANY, {}),
                "seconds": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 9999.0, "step": 0.01, "tooltip": "The amount of seconds to sleep."}),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            },
        }
    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "sleep"

    CATEGORY = "_for_testing"

    async def sleep(self, value, seconds, unique_id):
        pbar = ProgressBar(seconds, node_id=unique_id)
        start = time.time()
        expiration = start + seconds
        now = start
        while now < expiration:
            now = time.time()
            pbar.update_absolute(now - start)
            await asyncio.sleep(0.01)
        return (value,)

class TestParallelSleep(ComfyNodeABC):
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image1": ("IMAGE",),
                "image2": ("IMAGE",),
                "image3": ("IMAGE",),
                "sleep1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
                "sleep2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
                "sleep3": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            },
        }
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "parallel_sleep"
    CATEGORY = "_for_testing"
    OUTPUT_NODE = True

    def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id):
        # Create a graph dynamically with three TestSleep nodes
        g = GraphBuilder()

        # Create sleep nodes for each duration and image
        sleep_node1 = g.node("TestSleep", value=image1, seconds=sleep1)
        sleep_node2 = g.node("TestSleep", value=image2, seconds=sleep2)
        sleep_node3 = g.node("TestSleep", value=image3, seconds=sleep3)

        # Blend the results using TestVariadicAverage
        blend = g.node("TestVariadicAverage",
                       input1=sleep_node1.out(0),
                       input2=sleep_node2.out(0),
                       input3=sleep_node3.out(0))

        return {
            "result": (blend.out(0),),
            "expand": g.finalize(),
        }

class TestOutputNodeWithSocketOutput:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
            },
        }
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "process"
    CATEGORY = "_for_testing"
    OUTPUT_NODE = True

    def process(self, image, value):
        # Apply value scaling and return it both as a UI output and a socket output
        result = image * value
        return (result,)

TEST_NODE_CLASS_MAPPINGS = {
    "TestLazyMixImages": TestLazyMixImages,
    "TestVariadicAverage": TestVariadicAverage,
    "TestCustomIsChanged": TestCustomIsChanged,
    "TestIsChangedWithConstants": TestIsChangedWithConstants,
    "TestCustomValidation1": TestCustomValidation1,
    "TestCustomValidation2": TestCustomValidation2,
    "TestCustomValidation3": TestCustomValidation3,
    "TestCustomValidation4": TestCustomValidation4,
    "TestCustomValidation5": TestCustomValidation5,
    "TestDynamicDependencyCycle": TestDynamicDependencyCycle,
    "TestMixedExpansionReturns": TestMixedExpansionReturns,
    "TestSamplingInExpansion": TestSamplingInExpansion,
    "TestSleep": TestSleep,
    "TestParallelSleep": TestParallelSleep,
    "TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
}

TEST_NODE_DISPLAY_NAME_MAPPINGS = {
    "TestLazyMixImages": "Lazy Mix Images",
    "TestVariadicAverage": "Variadic Average",
    "TestCustomIsChanged": "Custom IsChanged",
    "TestIsChangedWithConstants": "IsChanged With Constants",
    "TestCustomValidation1": "Custom Validation 1",
    "TestCustomValidation2": "Custom Validation 2",
    "TestCustomValidation3": "Custom Validation 3",
    "TestCustomValidation4": "Custom Validation 4",
    "TestCustomValidation5": "Custom Validation 5",
    "TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
    "TestMixedExpansionReturns": "Mixed Expansion Returns",
    "TestSamplingInExpansion": "Sampling In Expansion",
    "TestSleep": "Test Sleep",
    "TestParallelSleep": "Test Parallel Sleep",
    "TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
}
129
tests/execution/testing_nodes/testing-pack/stubs.py
Normal file
@@ -0,0 +1,129 @@
import torch

class StubImage:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "content": (['WHITE', 'BLACK', 'NOISE'],),
                "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
                "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "stub_image"

    CATEGORY = "Testing/Stub Nodes"

    def stub_image(self, content, height, width, batch_size):
        if content == "WHITE":
            return (torch.ones(batch_size, height, width, 3),)
        elif content == "BLACK":
            return (torch.zeros(batch_size, height, width, 3),)
        elif content == "NOISE":
            return (torch.rand(batch_size, height, width, 3),)

class StubConstantImage:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
                "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "stub_constant_image"

    CATEGORY = "Testing/Stub Nodes"

    def stub_constant_image(self, value, height, width, batch_size):
        return (torch.ones(batch_size, height, width, 3) * value,)

class StubMask:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
                "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
            },
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "stub_mask"

    CATEGORY = "Testing/Stub Nodes"

    def stub_mask(self, value, height, width, batch_size):
        return (torch.ones(batch_size, height, width) * value,)

class StubInt:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("INT", {"default": 0, "min": -0xffffffff, "max": 0xffffffff, "step": 1}),
            },
        }

    RETURN_TYPES = ("INT",)
    FUNCTION = "stub_int"

    CATEGORY = "Testing/Stub Nodes"

    def stub_int(self, value):
        return (value,)

class StubFloat:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("FLOAT", {"default": 0.0, "min": -1.0e38, "max": 1.0e38, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "stub_float"

    CATEGORY = "Testing/Stub Nodes"

    def stub_float(self, value):
        return (value,)

TEST_STUB_NODE_CLASS_MAPPINGS = {
    "StubImage": StubImage,
    "StubConstantImage": StubConstantImage,
    "StubMask": StubMask,
    "StubInt": StubInt,
    "StubFloat": StubFloat,
}
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
    "StubImage": "Stub Image",
    "StubConstantImage": "Stub Constant Image",
    "StubMask": "Stub Mask",
    "StubInt": "Stub Int",
    "StubFloat": "Stub Float",
}
53
tests/execution/testing_nodes/testing-pack/tools.py
Normal file
@@ -0,0 +1,53 @@
def MakeSmartType(t):
    if isinstance(t, str):
        return SmartType(t)
    return t

class SmartType(str):
    def __ne__(self, other):
        if self == "*" or other == "*":
            return False
        selfset = set(self.split(','))
        otherset = set(other.split(','))
        return not selfset.issubset(otherset)
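
    # SmartType makes "*" match any type and treats a comma-separated type
    # string as a set: e.g. SmartType("IMAGE") != "IMAGE,FLOAT" is False,
    # because {"IMAGE"} is a subset of {"IMAGE", "FLOAT"}.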
|
||||
|
||||
def VariantSupport():
|
||||
def decorator(cls):
|
||||
if hasattr(cls, "INPUT_TYPES"):
|
||||
old_input_types = getattr(cls, "INPUT_TYPES")
|
||||
def new_input_types(*args, **kwargs):
|
||||
types = old_input_types(*args, **kwargs)
|
||||
for category in ["required", "optional"]:
|
||||
if category not in types:
|
||||
continue
|
||||
for key, value in types[category].items():
|
||||
if isinstance(value, tuple):
|
||||
types[category][key] = (MakeSmartType(value[0]),) + value[1:]
|
||||
return types
|
||||
setattr(cls, "INPUT_TYPES", new_input_types)
|
||||
if hasattr(cls, "RETURN_TYPES"):
|
||||
old_return_types = cls.RETURN_TYPES
|
||||
setattr(cls, "RETURN_TYPES", tuple(MakeSmartType(x) for x in old_return_types))
|
||||
if hasattr(cls, "VALIDATE_INPUTS"):
|
||||
# Reflection is used to determine what the function signature is, so we can't just change the function signature
|
||||
raise NotImplementedError("VariantSupport does not support VALIDATE_INPUTS yet")
|
||||
else:
|
||||
def validate_inputs(input_types):
|
||||
inputs = cls.INPUT_TYPES()
|
||||
for key, value in input_types.items():
|
||||
if isinstance(value, SmartType):
|
||||
continue
|
||||
if "required" in inputs and key in inputs["required"]:
|
||||
expected_type = inputs["required"][key][0]
|
||||
elif "optional" in inputs and key in inputs["optional"]:
|
||||
expected_type = inputs["optional"][key][0]
|
||||
else:
|
||||
expected_type = None
|
||||
if expected_type is not None and MakeSmartType(value) != expected_type:
|
||||
return f"Invalid type of {key}: {value} (expected {expected_type})"
|
||||
return True
|
||||
setattr(cls, "VALIDATE_INPUTS", validate_inputs)
|
||||
return cls
|
||||
return decorator
|
||||
|
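To make the wildcard semantics concrete: SmartType overrides only __ne__ (the validate_inputs hook above compares types with !=), leaving __eq__ as plain str equality. "*" matches everything, and a comma-separated type matches any superset of its members. A standalone sketch, assuming the definitions above:

# Standalone illustration of SmartType matching (not code from the commit).
assert not (SmartType("*") != "INT")          # "*" never mismatches
assert not (SmartType("INT") != "INT,FLOAT")  # {"INT"} is a subset of {"INT", "FLOAT"}
assert SmartType("INT") != "STRING"           # disjoint types do mismatch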
364
tests/execution/testing_nodes/testing-pack/util.py
Normal file
@@ -0,0 +1,364 @@
from comfy_execution.graph_utils import GraphBuilder
from .tools import VariantSupport

@VariantSupport()
class TestAccumulateNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "to_add": ("*",),
            },
            "optional": {
                "accumulation": ("ACCUMULATION",),
            },
        }

    RETURN_TYPES = ("ACCUMULATION",)
    FUNCTION = "accumulate"

    CATEGORY = "Testing/Lists"

    def accumulate(self, to_add, accumulation = None):
        if accumulation is None:
            value = [to_add]
        else:
            value = accumulation["accum"] + [to_add]
        return ({"accum": value},)

@VariantSupport()
class TestAccumulationHeadNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "accumulation": ("ACCUMULATION",),
            },
        }

    RETURN_TYPES = ("ACCUMULATION", "*",)
    FUNCTION = "accumulation_head"

    CATEGORY = "Testing/Lists"

    def accumulation_head(self, accumulation):
        accum = accumulation["accum"]
        if len(accum) == 0:
            return (accumulation, None)
        else:
            return ({"accum": accum[1:]}, accum[0])

class TestAccumulationTailNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "accumulation": ("ACCUMULATION",),
            },
        }

    RETURN_TYPES = ("ACCUMULATION", "*",)
    FUNCTION = "accumulation_tail"

    CATEGORY = "Testing/Lists"

    def accumulation_tail(self, accumulation):
        accum = accumulation["accum"]
        if len(accum) == 0:
            return (None, accumulation)
        else:
            return ({"accum": accum[:-1]}, accum[-1])
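# For illustration (a sketch, not part of this file): the ACCUMULATION value
# passed between these nodes is just a dict holding a list under "accum".
# Calling the node methods above directly, outside a graph:
#
#     acc = TestAccumulateNode().accumulate("a")[0]       # {"accum": ["a"]}
#     acc = TestAccumulateNode().accumulate("b", acc)[0]  # {"accum": ["a", "b"]}
#     rest, first = TestAccumulationHeadNode().accumulation_head(acc)
#     assert first == "a" and rest == {"accum": ["b"]}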
@VariantSupport()
class TestAccumulationToListNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "accumulation": ("ACCUMULATION",),
            },
        }

    RETURN_TYPES = ("*",)
    OUTPUT_IS_LIST = (True,)

    FUNCTION = "accumulation_to_list"

    CATEGORY = "Testing/Lists"

    def accumulation_to_list(self, accumulation):
        return (accumulation["accum"],)

@VariantSupport()
class TestListToAccumulationNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "list": ("*",),
            },
        }

    RETURN_TYPES = ("ACCUMULATION",)
    INPUT_IS_LIST = (True,)

    FUNCTION = "list_to_accumulation"

    CATEGORY = "Testing/Lists"

    def list_to_accumulation(self, list):
        return ({"accum": list},)

@VariantSupport()
class TestAccumulationGetLengthNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "accumulation": ("ACCUMULATION",),
            },
        }

    RETURN_TYPES = ("INT",)

    FUNCTION = "accumlength"

    CATEGORY = "Testing/Lists"

    def accumlength(self, accumulation):
        return (len(accumulation['accum']),)

@VariantSupport()
class TestAccumulationGetItemNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "accumulation": ("ACCUMULATION",),
                "index": ("INT", {"default":0, "step":1})
            },
        }

    RETURN_TYPES = ("*",)

    FUNCTION = "get_item"

    CATEGORY = "Testing/Lists"

    def get_item(self, accumulation, index):
        return (accumulation['accum'][index],)

@VariantSupport()
class TestAccumulationSetItemNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "accumulation": ("ACCUMULATION",),
                "index": ("INT", {"default":0, "step":1}),
                "value": ("*",),
            },
        }

    RETURN_TYPES = ("ACCUMULATION",)

    FUNCTION = "set_item"

    CATEGORY = "Testing/Lists"

    def set_item(self, accumulation, index, value):
        new_accum = accumulation['accum'][:]
        new_accum[index] = value
        return ({"accum": new_accum},)
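# List semantics, for reference (a sketch, not part of this file):
# OUTPUT_IS_LIST tells the executor that the output is a per-item list to fan
# out downstream, while INPUT_IS_LIST makes the executor deliver the whole
# collected list in a single call. The round trip through the two conversion
# nodes above, called directly:
#
#     acc = {"accum": [1, 2, 3]}
#     items = TestAccumulationToListNode().accumulation_to_list(acc)[0]   # [1, 2, 3]
#     acc2 = TestListToAccumulationNode().list_to_accumulation(items)[0]
#     assert acc2 == {"accum": [1, 2, 3]}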
class TestIntMathOperation:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
                "b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
                "operation": (["add", "subtract", "multiply", "divide", "modulo", "power"],),
            },
        }

    RETURN_TYPES = ("INT",)
    FUNCTION = "int_math_operation"

    CATEGORY = "Testing/Logic"

    def int_math_operation(self, a, b, operation):
        if operation == "add":
            return (a + b,)
        elif operation == "subtract":
            return (a - b,)
        elif operation == "multiply":
            return (a * b,)
        elif operation == "divide":
            return (a // b,)
        elif operation == "modulo":
            return (a % b,)
        elif operation == "power":
            return (a ** b,)
from .flow_control import NUM_FLOW_SOCKETS

@VariantSupport()
class TestForLoopOpen:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "remaining": ("INT", {"default": 1, "min": 0, "max": 100000, "step": 1}),
            },
            "optional": {
                f"initial_value{i}": ("*",) for i in range(1, NUM_FLOW_SOCKETS)
            },
            "hidden": {
                "initial_value0": ("*",)
            }
        }

    RETURN_TYPES = tuple(["FLOW_CONTROL", "INT",] + ["*"] * (NUM_FLOW_SOCKETS-1))
    RETURN_NAMES = tuple(["flow_control", "remaining"] + [f"value{i}" for i in range(1, NUM_FLOW_SOCKETS)])
    FUNCTION = "for_loop_open"

    CATEGORY = "Testing/Flow"

    def for_loop_open(self, remaining, **kwargs):
        graph = GraphBuilder()
        if "initial_value0" in kwargs:
            remaining = kwargs["initial_value0"]
        graph.node("TestWhileLoopOpen", condition=remaining, initial_value0=remaining, **{(f"initial_value{i}"): kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)})
        outputs = [kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)]
        return {
            "result": tuple(["stub", remaining] + outputs),
            "expand": graph.finalize(),
        }
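# How the expansion above works, for reference (a sketch, not code from this
# file): instead of plain values, for_loop_open returns a dict whose "result"
# holds one entry per declared output (literal values or links into the new
# subgraph) and whose "expand" holds the serialized subgraph produced by
# graph.finalize(); the executor splices that subgraph into the running
# workflow. Roughly (illustrative values only):
#
#     {
#         "result": ("stub", 5, None, ...),   # flow_control, remaining, value1..
#         "expand": {"<node_id>": {"class_type": "TestWhileLoopOpen", "inputs": {...}}},
#     }
#
# Because "flow_control" on the close node below is declared with
# "rawLink": True, the close node receives the raw [node_id, output_index]
# link instead of a resolved value, which is how for_loop_close can address
# the opening node's outputs directly (a=[while_open, 1]).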
@VariantSupport()
class TestForLoopClose:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "flow_control": ("FLOW_CONTROL", {"rawLink": True}),
            },
            "optional": {
                f"initial_value{i}": ("*",{"rawLink": True}) for i in range(1, NUM_FLOW_SOCKETS)
            },
        }

    RETURN_TYPES = tuple(["*"] * (NUM_FLOW_SOCKETS-1))
    RETURN_NAMES = tuple([f"value{i}" for i in range(1, NUM_FLOW_SOCKETS)])
    FUNCTION = "for_loop_close"

    CATEGORY = "Testing/Flow"

    def for_loop_close(self, flow_control, **kwargs):
        graph = GraphBuilder()
        while_open = flow_control[0]
        sub = graph.node("TestIntMathOperation", operation="subtract", a=[while_open,1], b=1)
        cond = graph.node("TestToBoolNode", value=sub.out(0))
        input_values = {f"initial_value{i}": kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)}
        while_close = graph.node("TestWhileLoopClose",
                flow_control=flow_control,
                condition=cond.out(0),
                initial_value0=sub.out(0),
                **input_values)
        return {
            "result": tuple([while_close.out(i) for i in range(1, NUM_FLOW_SOCKETS)]),
            "expand": graph.finalize(),
        }
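# Usage sketch (assumed wiring, not part of this file): a counted loop is
# built open -> body -> close; output index 2 of the open node is value1,
# per its RETURN_NAMES above.
#
#     g = GraphBuilder()
#     start = g.node("StubInt", value=0)
#     loop = g.node("TestForLoopOpen", remaining=5, initial_value1=start.out(0))
#     body = g.node("TestIntMathOperation", operation="add", a=loop.out(2), b=1)
#     done = g.node("TestForLoopClose", flow_control=loop.out(0), initial_value1=body.out(0))
#     # done.out(0) carries value1 after five iterations (here: 5)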
NUM_LIST_SOCKETS = 10
@VariantSupport()
class TestMakeListNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value1": ("*",),
            },
            "optional": {
                f"value{i}": ("*",) for i in range(1, NUM_LIST_SOCKETS)
            },
        }

    RETURN_TYPES = ("*",)
    FUNCTION = "make_list"
    OUTPUT_IS_LIST = (True,)

    CATEGORY = "Testing/Lists"

    def make_list(self, **kwargs):
        result = []
        for i in range(NUM_LIST_SOCKETS):
            if f"value{i}" in kwargs:
                result.append(kwargs[f"value{i}"])
        return (result,)
UTILITY_NODE_CLASS_MAPPINGS = {
    "TestAccumulateNode": TestAccumulateNode,
    "TestAccumulationHeadNode": TestAccumulationHeadNode,
    "TestAccumulationTailNode": TestAccumulationTailNode,
    "TestAccumulationToListNode": TestAccumulationToListNode,
    "TestListToAccumulationNode": TestListToAccumulationNode,
    "TestAccumulationGetLengthNode": TestAccumulationGetLengthNode,
    "TestAccumulationGetItemNode": TestAccumulationGetItemNode,
    "TestAccumulationSetItemNode": TestAccumulationSetItemNode,
    "TestForLoopOpen": TestForLoopOpen,
    "TestForLoopClose": TestForLoopClose,
    "TestIntMathOperation": TestIntMathOperation,
    "TestMakeListNode": TestMakeListNode,
}
UTILITY_NODE_DISPLAY_NAME_MAPPINGS = {
    "TestAccumulateNode": "Accumulate",
    "TestAccumulationHeadNode": "Accumulation Head",
    "TestAccumulationTailNode": "Accumulation Tail",
    "TestAccumulationToListNode": "Accumulation to List",
    "TestListToAccumulationNode": "List to Accumulation",
    "TestAccumulationGetLengthNode": "Accumulation Get Length",
    "TestAccumulationGetItemNode": "Accumulation Get Item",
    "TestAccumulationSetItemNode": "Accumulation Set Item",
    "TestForLoopOpen": "For Loop Open",
    "TestForLoopClose": "For Loop Close",
    "TestIntMathOperation": "Int Math Operation",
    "TestMakeListNode": "Make List",
}