Fix progress update crossover between users (#9706)

* Fix showing progress from other sessions

Because `client_id` was missing from the `progress_state` message, it
was being sent to all connected sessions. In practice, this meant that if
another connected user was running a graph containing the same node IDs,
they would see progress updates from other users' executions (the sketch
after the commit message illustrates the targeted-send idea).

Also added a regression test to prevent a recurrence, and reorganized the
tests to make them easier to hook up to CI.

* Fix CI issues related to timing-sensitive tests
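
For context, the fix boils down to attaching the originating session's `client_id` when the server emits the `progress_state` event, so the update is routed to a single WebSocket session instead of being broadcast. The snippet below is a minimal, self-contained sketch of that routing idea; the `FakeServer` class, its `send_sync(event, data, sid)` helper, the payload shape, and the session names are illustrative assumptions, not ComfyUI's actual server code.

```python
import json


class FakeServer:
    """Toy stand-in for a WebSocket server: one outgoing queue per session."""

    def __init__(self):
        self.sessions = {}  # client_id -> list of serialized messages

    def connect(self, client_id):
        self.sessions[client_id] = []

    def send_sync(self, event, data, sid=None):
        """Queue an event for one session (sid) or broadcast when sid is None."""
        message = json.dumps({"type": event, "data": data})
        targets = [sid] if sid is not None else list(self.sessions)
        for client_id in targets:
            self.sessions[client_id].append(message)


server = FakeServer()
server.connect("alice")
server.connect("bob")

progress = {"nodes": {"3": {"value": 0.5, "max": 1.0}}}

# Buggy behaviour: no client_id attached, so every connected session sees the update.
server.send_sync("progress_state", progress)

# Fixed behaviour: pass the originating session's client_id so only it is notified.
server.send_sync("progress_state", progress, sid="alice")

assert len(server.sessions["alice"]) == 2  # broadcast + targeted update
assert len(server.sessions["bob"]) == 1    # only the buggy broadcast leaked here
```

The `TestAsyncProgressUpdate` and `TestSyncProgressUpdate` nodes added below exercise the real path end to end by emitting progress from a running workflow.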
guill
2025-09-04 16:13:28 -07:00
committed by GitHub
parent b0338e930b
commit a9f1bb10a5
16 changed files with 295 additions and 18 deletions

View File

@@ -0,0 +1,28 @@
from .specific_tests import TEST_NODE_CLASS_MAPPINGS, TEST_NODE_DISPLAY_NAME_MAPPINGS
from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS
from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS
from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS
from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS
from .async_test_nodes import ASYNC_TEST_NODE_CLASS_MAPPINGS, ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS
from .api_test_nodes import API_TEST_NODE_CLASS_MAPPINGS, API_TEST_NODE_DISPLAY_NAME_MAPPINGS
# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)
NODE_CLASS_MAPPINGS = {}
NODE_CLASS_MAPPINGS.update(TEST_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(ASYNC_TEST_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(API_TEST_NODE_CLASS_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(API_TEST_NODE_DISPLAY_NAME_MAPPINGS)

View File

@@ -0,0 +1,78 @@
import asyncio
import time
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
from comfy_api.v0_0_2 import ComfyAPI, ComfyAPISync
api = ComfyAPI()
api_sync = ComfyAPISync()
class TestAsyncProgressUpdate(ComfyNodeABC):
"""Test node with async VALIDATE_INPUTS."""
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"value": (IO.ANY, {}),
"sleep_seconds": (IO.FLOAT, {"default": 1.0}),
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "execute"
CATEGORY = "_for_testing/async"
async def execute(self, value, sleep_seconds):
start = time.time()
expiration = start + sleep_seconds
now = start
while now < expiration:
now = time.time()
await api.execution.set_progress(
value=(now - start) / sleep_seconds,
max_value=1.0,
)
await asyncio.sleep(0.01)
return (value,)
class TestSyncProgressUpdate(ComfyNodeABC):
"""Test node with async VALIDATE_INPUTS."""
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"value": (IO.ANY, {}),
"sleep_seconds": (IO.FLOAT, {"default": 1.0}),
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "execute"
CATEGORY = "_for_testing/async"
def execute(self, value, sleep_seconds):
start = time.time()
expiration = start + sleep_seconds
now = start
while now < expiration:
now = time.time()
api_sync.execution.set_progress(
value=(now - start) / sleep_seconds,
max_value=1.0,
)
time.sleep(0.01)
return (value,)
API_TEST_NODE_CLASS_MAPPINGS = {
"TestAsyncProgressUpdate": TestAsyncProgressUpdate,
"TestSyncProgressUpdate": TestSyncProgressUpdate,
}
API_TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestAsyncProgressUpdate": "Async Progress Update Test Node",
"TestSyncProgressUpdate": "Sync Progress Update Test Node",
}

View File

@@ -0,0 +1,343 @@
import torch
import asyncio
from typing import Dict
from comfy.utils import ProgressBar
from comfy_execution.graph_utils import GraphBuilder
from comfy.comfy_types.node_typing import ComfyNodeABC
from comfy.comfy_types import IO
class TestAsyncValidation(ComfyNodeABC):
"""Test node with async VALIDATE_INPUTS."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 5.0}),
"threshold": ("FLOAT", {"default": 10.0}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "_for_testing/async"
@classmethod
async def VALIDATE_INPUTS(cls, value, threshold):
# Simulate async validation (e.g., checking remote service)
await asyncio.sleep(0.05)
if value > threshold:
return f"Value {value} exceeds threshold {threshold}"
return True
def process(self, value, threshold):
# Create image based on value
intensity = value / 10.0
image = torch.ones([1, 512, 512, 3]) * intensity
return (image,)
class TestAsyncError(ComfyNodeABC):
"""Test node that errors during async execution."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": (IO.ANY, {}),
"error_after": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 10.0}),
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "error_execution"
CATEGORY = "_for_testing/async"
async def error_execution(self, value, error_after):
await asyncio.sleep(error_after)
raise RuntimeError("Intentional async execution error for testing")
class TestAsyncValidationError(ComfyNodeABC):
"""Test node with async validation that always fails."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 5.0}),
"max_value": ("FLOAT", {"default": 10.0}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "_for_testing/async"
@classmethod
async def VALIDATE_INPUTS(cls, value, max_value):
await asyncio.sleep(0.05)
# Always fail validation for values > max_value
if value > max_value:
return f"Async validation failed: {value} > {max_value}"
return True
def process(self, value, max_value):
# This won't be reached if validation fails
image = torch.ones([1, 512, 512, 3]) * (value / max_value)
return (image,)
class TestAsyncTimeout(ComfyNodeABC):
"""Test node that simulates timeout scenarios."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": (IO.ANY, {}),
"timeout": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0}),
"operation_time": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 10.0}),
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "timeout_execution"
CATEGORY = "_for_testing/async"
async def timeout_execution(self, value, timeout, operation_time):
try:
# This will timeout if operation_time > timeout
await asyncio.wait_for(asyncio.sleep(operation_time), timeout=timeout)
return (value,)
except asyncio.TimeoutError:
raise RuntimeError(f"Operation timed out after {timeout} seconds")
class TestSyncError(ComfyNodeABC):
"""Test node that errors synchronously (for mixed sync/async testing)."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": (IO.ANY, {}),
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "sync_error"
CATEGORY = "_for_testing/async"
def sync_error(self, value):
raise RuntimeError("Intentional sync execution error for testing")
class TestAsyncLazyCheck(ComfyNodeABC):
"""Test node with async check_lazy_status."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": (IO.ANY, {"lazy": True}),
"input2": (IO.ANY, {"lazy": True}),
"condition": ("BOOLEAN", {"default": True}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "_for_testing/async"
async def check_lazy_status(self, condition, input1, input2):
# Simulate async checking (e.g., querying remote service)
await asyncio.sleep(0.05)
needed = []
if condition and input1 is None:
needed.append("input1")
if not condition and input2 is None:
needed.append("input2")
return needed
def process(self, input1, input2, condition):
# Return a simple image
return (torch.ones([1, 512, 512, 3]),)
class TestDynamicAsyncGeneration(ComfyNodeABC):
"""Test node that dynamically generates async nodes."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image1": ("IMAGE",),
"image2": ("IMAGE",),
"num_async_nodes": ("INT", {"default": 3, "min": 1, "max": 10}),
"sleep_duration": ("FLOAT", {"default": 0.2, "min": 0.1, "max": 1.0}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "generate_async_workflow"
CATEGORY = "_for_testing/async"
def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration):
g = GraphBuilder()
# Create multiple async sleep nodes
sleep_nodes = []
for i in range(num_async_nodes):
image = image1 if i % 2 == 0 else image2
sleep_node = g.node("TestSleep", value=image, seconds=sleep_duration)
sleep_nodes.append(sleep_node)
# Average all results
if len(sleep_nodes) == 1:
final_node = sleep_nodes[0]
else:
avg_inputs = {"input1": sleep_nodes[0].out(0)}
for i, node in enumerate(sleep_nodes[1:], 2):
avg_inputs[f"input{i}"] = node.out(0)
final_node = g.node("TestVariadicAverage", **avg_inputs)
return {
"result": (final_node.out(0),),
"expand": g.finalize(),
}
class TestAsyncResourceUser(ComfyNodeABC):
"""Test node that uses resources during async execution."""
# Class-level resource tracking for testing
_active_resources: Dict[str, bool] = {}
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": (IO.ANY, {}),
"resource_id": ("STRING", {"default": "resource_0"}),
"duration": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0}),
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "use_resource"
CATEGORY = "_for_testing/async"
async def use_resource(self, value, resource_id, duration):
# Check if resource is already in use
if self._active_resources.get(resource_id, False):
raise RuntimeError(f"Resource {resource_id} is already in use!")
# Mark resource as in use
self._active_resources[resource_id] = True
try:
# Simulate resource usage
await asyncio.sleep(duration)
return (value,)
finally:
# Always clean up resource
self._active_resources[resource_id] = False
class TestAsyncBatchProcessing(ComfyNodeABC):
"""Test async processing of batched inputs."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"process_time_per_item": ("FLOAT", {"default": 0.1, "min": 0.01, "max": 1.0}),
},
"hidden": {
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process_batch"
CATEGORY = "_for_testing/async"
async def process_batch(self, images, process_time_per_item, unique_id):
batch_size = images.shape[0]
pbar = ProgressBar(batch_size, node_id=unique_id)
# Process each image in the batch
processed = []
for i in range(batch_size):
# Simulate async processing
await asyncio.sleep(process_time_per_item)
# Simple processing: invert the image
processed_image = 1.0 - images[i:i+1]
processed.append(processed_image)
pbar.update(1)
# Stack processed images
result = torch.cat(processed, dim=0)
return (result,)
class TestAsyncConcurrentLimit(ComfyNodeABC):
"""Test concurrent execution limits for async nodes."""
_semaphore = asyncio.Semaphore(2) # Only allow 2 concurrent executions
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": (IO.ANY, {}),
"duration": ("FLOAT", {"default": 0.5, "min": 0.1, "max": 2.0}),
"node_id": ("INT", {"default": 0}),
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "limited_execution"
CATEGORY = "_for_testing/async"
async def limited_execution(self, value, duration, node_id):
async with self._semaphore:
# Node {node_id} acquired semaphore
await asyncio.sleep(duration)
# Node {node_id} releasing semaphore
return (value,)
# Add node mappings
ASYNC_TEST_NODE_CLASS_MAPPINGS = {
"TestAsyncValidation": TestAsyncValidation,
"TestAsyncError": TestAsyncError,
"TestAsyncValidationError": TestAsyncValidationError,
"TestAsyncTimeout": TestAsyncTimeout,
"TestSyncError": TestSyncError,
"TestAsyncLazyCheck": TestAsyncLazyCheck,
"TestDynamicAsyncGeneration": TestDynamicAsyncGeneration,
"TestAsyncResourceUser": TestAsyncResourceUser,
"TestAsyncBatchProcessing": TestAsyncBatchProcessing,
"TestAsyncConcurrentLimit": TestAsyncConcurrentLimit,
}
ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestAsyncValidation": "Test Async Validation",
"TestAsyncError": "Test Async Error",
"TestAsyncValidationError": "Test Async Validation Error",
"TestAsyncTimeout": "Test Async Timeout",
"TestSyncError": "Test Sync Error",
"TestAsyncLazyCheck": "Test Async Lazy Check",
"TestDynamicAsyncGeneration": "Test Dynamic Async Generation",
"TestAsyncResourceUser": "Test Async Resource User",
"TestAsyncBatchProcessing": "Test Async Batch Processing",
"TestAsyncConcurrentLimit": "Test Async Concurrent Limit",
}

View File

@@ -0,0 +1,194 @@
import re
import torch
class TestIntConditions:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
"b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
"operation": (["==", "!=", "<", ">", "<=", ">="],),
},
}
RETURN_TYPES = ("BOOLEAN",)
FUNCTION = "int_condition"
CATEGORY = "Testing/Logic"
def int_condition(self, a, b, operation):
if operation == "==":
return (a == b,)
elif operation == "!=":
return (a != b,)
elif operation == "<":
return (a < b,)
elif operation == ">":
return (a > b,)
elif operation == "<=":
return (a <= b,)
elif operation == ">=":
return (a >= b,)
class TestFloatConditions:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"a": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 999999999999.0, "step": 1}),
"b": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 999999999999.0, "step": 1}),
"operation": (["==", "!=", "<", ">", "<=", ">="],),
},
}
RETURN_TYPES = ("BOOLEAN",)
FUNCTION = "float_condition"
CATEGORY = "Testing/Logic"
def float_condition(self, a, b, operation):
if operation == "==":
return (a == b,)
elif operation == "!=":
return (a != b,)
elif operation == "<":
return (a < b,)
elif operation == ">":
return (a > b,)
elif operation == "<=":
return (a <= b,)
elif operation == ">=":
return (a >= b,)
class TestStringConditions:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"a": ("STRING", {"multiline": False}),
"b": ("STRING", {"multiline": False}),
"operation": (["a == b", "a != b", "a IN b", "a MATCH REGEX(b)", "a BEGINSWITH b", "a ENDSWITH b"],),
"case_sensitive": ("BOOLEAN", {"default": True}),
},
}
RETURN_TYPES = ("BOOLEAN",)
FUNCTION = "string_condition"
CATEGORY = "Testing/Logic"
def string_condition(self, a, b, operation, case_sensitive):
if not case_sensitive:
a = a.lower()
b = b.lower()
if operation == "a == b":
return (a == b,)
elif operation == "a != b":
return (a != b,)
elif operation == "a IN b":
return (a in b,)
elif operation == "a MATCH REGEX(b)":
try:
return (re.match(b, a) is not None,)
except:
return (False,)
elif operation == "a BEGINSWITH b":
return (a.startswith(b),)
elif operation == "a ENDSWITH b":
return (a.endswith(b),)
class TestToBoolNode:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("*",),
},
"optional": {
"invert": ("BOOLEAN", {"default": False}),
},
}
RETURN_TYPES = ("BOOLEAN",)
FUNCTION = "to_bool"
CATEGORY = "Testing/Logic"
def to_bool(self, value, invert = False):
if isinstance(value, torch.Tensor):
if value.max().item() == 0 and value.min().item() == 0:
result = False
else:
result = True
else:
try:
result = bool(value)
except:
# Can't convert it? Well then it's something or other. I dunno, I'm not a Python programmer.
result = True
if invert:
result = not result
return (result,)
class TestBoolOperationNode:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"a": ("BOOLEAN",),
"b": ("BOOLEAN",),
"op": (["a AND b", "a OR b", "a XOR b", "NOT a"],),
},
}
RETURN_TYPES = ("BOOLEAN",)
FUNCTION = "bool_operation"
CATEGORY = "Testing/Logic"
def bool_operation(self, a, b, op):
if op == "a AND b":
return (a and b,)
elif op == "a OR b":
return (a or b,)
elif op == "a XOR b":
return (a ^ b,)
elif op == "NOT a":
return (not a,)
CONDITION_NODE_CLASS_MAPPINGS = {
"TestIntConditions": TestIntConditions,
"TestFloatConditions": TestFloatConditions,
"TestStringConditions": TestStringConditions,
"TestToBoolNode": TestToBoolNode,
"TestBoolOperationNode": TestBoolOperationNode,
}
CONDITION_NODE_DISPLAY_NAME_MAPPINGS = {
"TestIntConditions": "Int Condition",
"TestFloatConditions": "Float Condition",
"TestStringConditions": "String Condition",
"TestToBoolNode": "To Bool",
"TestBoolOperationNode": "Bool Operation",
}

View File

@@ -0,0 +1,173 @@
from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.graph import ExecutionBlocker
from .tools import VariantSupport
NUM_FLOW_SOCKETS = 5
@VariantSupport()
class TestWhileLoopOpen:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
inputs = {
"required": {
"condition": ("BOOLEAN", {"default": True}),
},
"optional": {
},
}
for i in range(NUM_FLOW_SOCKETS):
inputs["optional"][f"initial_value{i}"] = ("*",)
return inputs
RETURN_TYPES = tuple(["FLOW_CONTROL"] + ["*"] * NUM_FLOW_SOCKETS)
RETURN_NAMES = tuple(["FLOW_CONTROL"] + [f"value{i}" for i in range(NUM_FLOW_SOCKETS)])
FUNCTION = "while_loop_open"
CATEGORY = "Testing/Flow"
def while_loop_open(self, condition, **kwargs):
values = []
for i in range(NUM_FLOW_SOCKETS):
values.append(kwargs.get(f"initial_value{i}", None))
return tuple(["stub"] + values)
@VariantSupport()
class TestWhileLoopClose:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
inputs = {
"required": {
"flow_control": ("FLOW_CONTROL", {"rawLink": True}),
"condition": ("BOOLEAN", {"forceInput": True}),
},
"optional": {
},
"hidden": {
"dynprompt": "DYNPROMPT",
"unique_id": "UNIQUE_ID",
}
}
for i in range(NUM_FLOW_SOCKETS):
inputs["optional"][f"initial_value{i}"] = ("*",)
return inputs
RETURN_TYPES = tuple(["*"] * NUM_FLOW_SOCKETS)
RETURN_NAMES = tuple([f"value{i}" for i in range(NUM_FLOW_SOCKETS)])
FUNCTION = "while_loop_close"
CATEGORY = "Testing/Flow"
def explore_dependencies(self, node_id, dynprompt, upstream):
node_info = dynprompt.get_node(node_id)
if "inputs" not in node_info:
return
for k, v in node_info["inputs"].items():
if is_link(v):
parent_id = v[0]
if parent_id not in upstream:
upstream[parent_id] = []
self.explore_dependencies(parent_id, dynprompt, upstream)
upstream[parent_id].append(node_id)
def collect_contained(self, node_id, upstream, contained):
if node_id not in upstream:
return
for child_id in upstream[node_id]:
if child_id not in contained:
contained[child_id] = True
self.collect_contained(child_id, upstream, contained)
def while_loop_close(self, flow_control, condition, dynprompt=None, unique_id=None, **kwargs):
assert dynprompt is not None
if not condition:
# We're done with the loop
values = []
for i in range(NUM_FLOW_SOCKETS):
values.append(kwargs.get(f"initial_value{i}", None))
return tuple(values)
# We want to loop
upstream = {}
# Get the list of all nodes between the open and close nodes
self.explore_dependencies(unique_id, dynprompt, upstream)
contained = {}
open_node = flow_control[0]
self.collect_contained(open_node, upstream, contained)
contained[unique_id] = True
contained[open_node] = True
# We'll use the default prefix, but to avoid having node names grow exponentially in size,
# we'll use "Recurse" for the name of the recursively-generated copy of this node.
graph = GraphBuilder()
for node_id in contained:
original_node = dynprompt.get_node(node_id)
node = graph.node(original_node["class_type"], "Recurse" if node_id == unique_id else node_id)
node.set_override_display_id(node_id)
for node_id in contained:
original_node = dynprompt.get_node(node_id)
node = graph.lookup_node("Recurse" if node_id == unique_id else node_id)
assert node is not None
for k, v in original_node["inputs"].items():
if is_link(v) and v[0] in contained:
parent = graph.lookup_node(v[0])
assert parent is not None
node.set_input(k, parent.out(v[1]))
else:
node.set_input(k, v)
new_open = graph.lookup_node(open_node)
assert new_open is not None
for i in range(NUM_FLOW_SOCKETS):
key = f"initial_value{i}"
new_open.set_input(key, kwargs.get(key, None))
my_clone = graph.lookup_node("Recurse")
assert my_clone is not None
result = map(lambda x: my_clone.out(x), range(NUM_FLOW_SOCKETS))
return {
"result": tuple(result),
"expand": graph.finalize(),
}
@VariantSupport()
class TestExecutionBlockerNode:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
inputs = {
"required": {
"input": ("*",),
"block": ("BOOLEAN",),
"verbose": ("BOOLEAN", {"default": False}),
},
}
return inputs
RETURN_TYPES = ("*",)
RETURN_NAMES = ("output",)
FUNCTION = "execution_blocker"
CATEGORY = "Testing/Flow"
def execution_blocker(self, input, block, verbose):
if block:
return (ExecutionBlocker("Blocked Execution" if verbose else None),)
return (input,)
FLOW_CONTROL_NODE_CLASS_MAPPINGS = {
"TestWhileLoopOpen": TestWhileLoopOpen,
"TestWhileLoopClose": TestWhileLoopClose,
"TestExecutionBlocker": TestExecutionBlockerNode,
}
FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS = {
"TestWhileLoopOpen": "While Loop Open",
"TestWhileLoopClose": "While Loop Close",
"TestExecutionBlocker": "Execution Blocker",
}

View File

@@ -0,0 +1,519 @@
import torch
import time
import asyncio
from comfy.utils import ProgressBar
from .tools import VariantSupport
from comfy_execution.graph_utils import GraphBuilder
from comfy.comfy_types.node_typing import ComfyNodeABC
from comfy.comfy_types import IO
class TestLazyMixImages:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image1": ("IMAGE",{"lazy": True}),
"image2": ("IMAGE",{"lazy": True}),
"mask": ("MASK",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "mix"
CATEGORY = "Testing/Nodes"
def check_lazy_status(self, mask, image1, image2):
mask_min = mask.min()
mask_max = mask.max()
needed = []
if image1 is None and (mask_min != 1.0 or mask_max != 1.0):
needed.append("image1")
if image2 is None and (mask_min != 0.0 or mask_max != 0.0):
needed.append("image2")
return needed
# Not trying to handle different batch sizes here just to keep the demo simple
def mix(self, mask, image1, image2):
mask_min = mask.min()
mask_max = mask.max()
if mask_min == 0.0 and mask_max == 0.0:
return (image1,)
elif mask_min == 1.0 and mask_max == 1.0:
return (image2,)
if len(mask.shape) == 2:
mask = mask.unsqueeze(0)
if len(mask.shape) == 3:
mask = mask.unsqueeze(3)
if mask.shape[3] < image1.shape[3]:
mask = mask.repeat(1, 1, 1, image1.shape[3])
        result = image1 * (1. - mask) + image2 * mask
        return (result,)
class TestVariadicAverage:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": ("IMAGE",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "variadic_average"
CATEGORY = "Testing/Nodes"
def variadic_average(self, input1, **kwargs):
inputs = [input1]
while 'input' + str(len(inputs) + 1) in kwargs:
inputs.append(kwargs['input' + str(len(inputs) + 1)])
return (torch.stack(inputs).mean(dim=0),)
class TestCustomIsChanged:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
},
"optional": {
"should_change": ("BOOL", {"default": False}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "custom_is_changed"
CATEGORY = "Testing/Nodes"
def custom_is_changed(self, image, should_change=False):
return (image,)
@classmethod
def IS_CHANGED(cls, should_change=False, *args, **kwargs):
if should_change:
return float("NaN")
else:
return False
class TestIsChangedWithConstants:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "custom_is_changed"
CATEGORY = "Testing/Nodes"
def custom_is_changed(self, image, value):
return (image * value,)
@classmethod
def IS_CHANGED(cls, image, value):
if image is None:
return value
else:
return image.mean().item() * value
class TestCustomValidation1:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": ("IMAGE,FLOAT",),
"input2": ("IMAGE,FLOAT",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "custom_validation1"
CATEGORY = "Testing/Nodes"
def custom_validation1(self, input1, input2):
if isinstance(input1, float) and isinstance(input2, float):
result = torch.ones([1, 512, 512, 3]) * input1 * input2
else:
result = input1 * input2
return (result,)
@classmethod
def VALIDATE_INPUTS(cls, input1=None, input2=None):
if input1 is not None:
if not isinstance(input1, (torch.Tensor, float)):
return f"Invalid type of input1: {type(input1)}"
if input2 is not None:
if not isinstance(input2, (torch.Tensor, float)):
return f"Invalid type of input2: {type(input2)}"
return True
class TestCustomValidation2:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": ("IMAGE,FLOAT",),
"input2": ("IMAGE,FLOAT",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "custom_validation2"
CATEGORY = "Testing/Nodes"
def custom_validation2(self, input1, input2):
if isinstance(input1, float) and isinstance(input2, float):
result = torch.ones([1, 512, 512, 3]) * input1 * input2
else:
result = input1 * input2
return (result,)
@classmethod
def VALIDATE_INPUTS(cls, input_types, input1=None, input2=None):
if input1 is not None:
if not isinstance(input1, (torch.Tensor, float)):
return f"Invalid type of input1: {type(input1)}"
if input2 is not None:
if not isinstance(input2, (torch.Tensor, float)):
return f"Invalid type of input2: {type(input2)}"
if 'input1' in input_types:
if input_types['input1'] not in ["IMAGE", "FLOAT"]:
return f"Invalid type of input1: {input_types['input1']}"
if 'input2' in input_types:
if input_types['input2'] not in ["IMAGE", "FLOAT"]:
return f"Invalid type of input2: {input_types['input2']}"
return True
@VariantSupport()
class TestCustomValidation3:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": ("IMAGE,FLOAT",),
"input2": ("IMAGE,FLOAT",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "custom_validation3"
CATEGORY = "Testing/Nodes"
def custom_validation3(self, input1, input2):
if isinstance(input1, float) and isinstance(input2, float):
result = torch.ones([1, 512, 512, 3]) * input1 * input2
else:
result = input1 * input2
return (result,)
class TestCustomValidation4:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": ("FLOAT",),
"input2": ("FLOAT",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "custom_validation4"
CATEGORY = "Testing/Nodes"
def custom_validation4(self, input1, input2):
result = torch.ones([1, 512, 512, 3]) * input1 * input2
return (result,)
@classmethod
def VALIDATE_INPUTS(cls, input1, input2):
if input1 is not None:
if not isinstance(input1, float):
return f"Invalid type of input1: {type(input1)}"
if input2 is not None:
if not isinstance(input2, float):
return f"Invalid type of input2: {type(input2)}"
return True
class TestCustomValidation5:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": ("FLOAT", {"min": 0.0, "max": 1.0}),
"input2": ("FLOAT", {"min": 0.0, "max": 1.0}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "custom_validation5"
CATEGORY = "Testing/Nodes"
def custom_validation5(self, input1, input2):
value = input1 * input2
return (torch.ones([1, 512, 512, 3]) * value,)
@classmethod
def VALIDATE_INPUTS(cls, **kwargs):
if kwargs['input2'] == 7.0:
return "7s are not allowed. I've never liked 7s."
return True
class TestDynamicDependencyCycle:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": ("IMAGE",),
"input2": ("IMAGE",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "dynamic_dependency_cycle"
CATEGORY = "Testing/Nodes"
def dynamic_dependency_cycle(self, input1, input2):
g = GraphBuilder()
mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
mix1 = g.node("TestLazyMixImages", image1=input1, mask=mask.out(0))
mix2 = g.node("TestLazyMixImages", image1=mix1.out(0), image2=input2, mask=mask.out(0))
        # Create the cycle
mix1.set_input("image2", mix2.out(0))
return {
"result": (mix2.out(0),),
"expand": g.finalize(),
}
class TestMixedExpansionReturns:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": ("FLOAT",),
},
}
RETURN_TYPES = ("IMAGE","IMAGE")
FUNCTION = "mixed_expansion_returns"
CATEGORY = "Testing/Nodes"
def mixed_expansion_returns(self, input1):
white_image = torch.ones([1, 512, 512, 3])
if input1 <= 0.1:
return (torch.ones([1, 512, 512, 3]) * 0.1, white_image)
elif input1 <= 0.2:
return {
"result": (torch.ones([1, 512, 512, 3]) * 0.2, white_image),
}
else:
g = GraphBuilder()
mask = g.node("StubMask", value=0.3, height=512, width=512, batch_size=1)
black = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
white = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
mix = g.node("TestLazyMixImages", image1=black.out(0), image2=white.out(0), mask=mask.out(0))
return {
"result": (mix.out(0), white_image),
"expand": g.finalize(),
}
class TestSamplingInExpansion:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("MODEL",),
"clip": ("CLIP",),
"vae": ("VAE",),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"steps": ("INT", {"default": 20, "min": 1, "max": 100}),
"cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 30.0}),
"prompt": ("STRING", {"multiline": True, "default": "a beautiful landscape with mountains and trees"}),
"negative_prompt": ("STRING", {"multiline": True, "default": "blurry, bad quality, worst quality"}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "sampling_in_expansion"
CATEGORY = "Testing/Nodes"
def sampling_in_expansion(self, model, clip, vae, seed, steps, cfg, prompt, negative_prompt):
g = GraphBuilder()
# Create a basic image generation workflow using the input model, clip and vae
# 1. Setup text prompts using the provided CLIP model
positive_prompt = g.node("CLIPTextEncode",
text=prompt,
clip=clip)
negative_prompt = g.node("CLIPTextEncode",
text=negative_prompt,
clip=clip)
# 2. Create empty latent with specified size
empty_latent = g.node("EmptyLatentImage", width=512, height=512, batch_size=1)
# 3. Setup sampler and generate image latent
sampler = g.node("KSampler",
model=model,
positive=positive_prompt.out(0),
negative=negative_prompt.out(0),
latent_image=empty_latent.out(0),
seed=seed,
steps=steps,
cfg=cfg,
sampler_name="euler_ancestral",
scheduler="normal")
# 4. Decode latent to image using VAE
output = g.node("VAEDecode", samples=sampler.out(0), vae=vae)
return {
"result": (output.out(0),),
"expand": g.finalize(),
}
class TestSleep(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": (IO.ANY, {}),
"seconds": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 9999.0, "step": 0.01, "tooltip": "The amount of seconds to sleep."}),
},
"hidden": {
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "sleep"
CATEGORY = "_for_testing"
async def sleep(self, value, seconds, unique_id):
pbar = ProgressBar(seconds, node_id=unique_id)
start = time.time()
expiration = start + seconds
now = start
while now < expiration:
now = time.time()
pbar.update_absolute(now - start)
await asyncio.sleep(0.01)
return (value,)
class TestParallelSleep(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image1": ("IMAGE", ),
"image2": ("IMAGE", ),
"image3": ("IMAGE", ),
"sleep1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
"sleep2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
"sleep3": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
},
"hidden": {
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "parallel_sleep"
CATEGORY = "_for_testing"
OUTPUT_NODE = True
def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id):
# Create a graph dynamically with three TestSleep nodes
g = GraphBuilder()
# Create sleep nodes for each duration and image
sleep_node1 = g.node("TestSleep", value=image1, seconds=sleep1)
sleep_node2 = g.node("TestSleep", value=image2, seconds=sleep2)
sleep_node3 = g.node("TestSleep", value=image3, seconds=sleep3)
# Blend the results using TestVariadicAverage
blend = g.node("TestVariadicAverage",
input1=sleep_node1.out(0),
input2=sleep_node2.out(0),
input3=sleep_node3.out(0))
return {
"result": (blend.out(0),),
"expand": g.finalize(),
}
class TestOutputNodeWithSocketOutput:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "_for_testing"
OUTPUT_NODE = True
def process(self, image, value):
# Apply value scaling and return both as output and socket
result = image * value
return (result,)
TEST_NODE_CLASS_MAPPINGS = {
"TestLazyMixImages": TestLazyMixImages,
"TestVariadicAverage": TestVariadicAverage,
"TestCustomIsChanged": TestCustomIsChanged,
"TestIsChangedWithConstants": TestIsChangedWithConstants,
"TestCustomValidation1": TestCustomValidation1,
"TestCustomValidation2": TestCustomValidation2,
"TestCustomValidation3": TestCustomValidation3,
"TestCustomValidation4": TestCustomValidation4,
"TestCustomValidation5": TestCustomValidation5,
"TestDynamicDependencyCycle": TestDynamicDependencyCycle,
"TestMixedExpansionReturns": TestMixedExpansionReturns,
"TestSamplingInExpansion": TestSamplingInExpansion,
"TestSleep": TestSleep,
"TestParallelSleep": TestParallelSleep,
"TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
}
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestLazyMixImages": "Lazy Mix Images",
"TestVariadicAverage": "Variadic Average",
"TestCustomIsChanged": "Custom IsChanged",
"TestIsChangedWithConstants": "IsChanged With Constants",
"TestCustomValidation1": "Custom Validation 1",
"TestCustomValidation2": "Custom Validation 2",
"TestCustomValidation3": "Custom Validation 3",
"TestCustomValidation4": "Custom Validation 4",
"TestCustomValidation5": "Custom Validation 5",
"TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
"TestMixedExpansionReturns": "Mixed Expansion Returns",
"TestSamplingInExpansion": "Sampling In Expansion",
"TestSleep": "Test Sleep",
"TestParallelSleep": "Test Parallel Sleep",
"TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
}

View File

@@ -0,0 +1,129 @@
import torch
class StubImage:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"content": (['WHITE', 'BLACK', 'NOISE'],),
"height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
"width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stub_image"
CATEGORY = "Testing/Stub Nodes"
def stub_image(self, content, height, width, batch_size):
if content == "WHITE":
return (torch.ones(batch_size, height, width, 3),)
elif content == "BLACK":
return (torch.zeros(batch_size, height, width, 3),)
elif content == "NOISE":
return (torch.rand(batch_size, height, width, 3),)
class StubConstantImage:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
"height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
"width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stub_constant_image"
CATEGORY = "Testing/Stub Nodes"
def stub_constant_image(self, value, height, width, batch_size):
return (torch.ones(batch_size, height, width, 3) * value,)
class StubMask:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
"height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
"width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
},
}
RETURN_TYPES = ("MASK",)
FUNCTION = "stub_mask"
CATEGORY = "Testing/Stub Nodes"
def stub_mask(self, value, height, width, batch_size):
return (torch.ones(batch_size, height, width) * value,)
class StubInt:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("INT", {"default": 0, "min": -0xffffffff, "max": 0xffffffff, "step": 1}),
},
}
RETURN_TYPES = ("INT",)
FUNCTION = "stub_int"
CATEGORY = "Testing/Stub Nodes"
def stub_int(self, value):
return (value,)
class StubFloat:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 0.0, "min": -1.0e38, "max": 1.0e38, "step": 0.01}),
},
}
RETURN_TYPES = ("FLOAT",)
FUNCTION = "stub_float"
CATEGORY = "Testing/Stub Nodes"
def stub_float(self, value):
return (value,)
TEST_STUB_NODE_CLASS_MAPPINGS = {
"StubImage": StubImage,
"StubConstantImage": StubConstantImage,
"StubMask": StubMask,
"StubInt": StubInt,
"StubFloat": StubFloat,
}
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
"StubImage": "Stub Image",
"StubConstantImage": "Stub Constant Image",
"StubMask": "Stub Mask",
"StubInt": "Stub Int",
"StubFloat": "Stub Float",
}

View File

@@ -0,0 +1,53 @@
def MakeSmartType(t):
if isinstance(t, str):
return SmartType(t)
return t
class SmartType(str):
def __ne__(self, other):
if self == "*" or other == "*":
return False
selfset = set(self.split(','))
otherset = set(other.split(','))
return not selfset.issubset(otherset)
def VariantSupport():
def decorator(cls):
if hasattr(cls, "INPUT_TYPES"):
old_input_types = getattr(cls, "INPUT_TYPES")
def new_input_types(*args, **kwargs):
types = old_input_types(*args, **kwargs)
for category in ["required", "optional"]:
if category not in types:
continue
for key, value in types[category].items():
if isinstance(value, tuple):
types[category][key] = (MakeSmartType(value[0]),) + value[1:]
return types
setattr(cls, "INPUT_TYPES", new_input_types)
if hasattr(cls, "RETURN_TYPES"):
old_return_types = cls.RETURN_TYPES
setattr(cls, "RETURN_TYPES", tuple(MakeSmartType(x) for x in old_return_types))
if hasattr(cls, "VALIDATE_INPUTS"):
# Reflection is used to determine what the function signature is, so we can't just change the function signature
raise NotImplementedError("VariantSupport does not support VALIDATE_INPUTS yet")
else:
def validate_inputs(input_types):
inputs = cls.INPUT_TYPES()
for key, value in input_types.items():
if isinstance(value, SmartType):
continue
if "required" in inputs and key in inputs["required"]:
expected_type = inputs["required"][key][0]
elif "optional" in inputs and key in inputs["optional"]:
expected_type = inputs["optional"][key][0]
else:
expected_type = None
if expected_type is not None and MakeSmartType(value) != expected_type:
return f"Invalid type of {key}: {value} (expected {expected_type})"
return True
setattr(cls, "VALIDATE_INPUTS", validate_inputs)
return cls
return decorator

View File

@@ -0,0 +1,364 @@
from comfy_execution.graph_utils import GraphBuilder
from .tools import VariantSupport
@VariantSupport()
class TestAccumulateNode:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"to_add": ("*",),
},
"optional": {
"accumulation": ("ACCUMULATION",),
},
}
RETURN_TYPES = ("ACCUMULATION",)
FUNCTION = "accumulate"
CATEGORY = "Testing/Lists"
def accumulate(self, to_add, accumulation = None):
if accumulation is None:
value = [to_add]
else:
value = accumulation["accum"] + [to_add]
return ({"accum": value},)
@VariantSupport()
class TestAccumulationHeadNode:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"accumulation": ("ACCUMULATION",),
},
}
RETURN_TYPES = ("ACCUMULATION", "*",)
FUNCTION = "accumulation_head"
CATEGORY = "Testing/Lists"
def accumulation_head(self, accumulation):
accum = accumulation["accum"]
if len(accum) == 0:
return (accumulation, None)
else:
return ({"accum": accum[1:]}, accum[0])
class TestAccumulationTailNode:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"accumulation": ("ACCUMULATION",),
},
}
RETURN_TYPES = ("ACCUMULATION", "*",)
FUNCTION = "accumulation_tail"
CATEGORY = "Testing/Lists"
def accumulation_tail(self, accumulation):
accum = accumulation["accum"]
if len(accum) == 0:
            return (accumulation, None)
else:
return ({"accum": accum[:-1]}, accum[-1])
@VariantSupport()
class TestAccumulationToListNode:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"accumulation": ("ACCUMULATION",),
},
}
RETURN_TYPES = ("*",)
OUTPUT_IS_LIST = (True,)
FUNCTION = "accumulation_to_list"
CATEGORY = "Testing/Lists"
def accumulation_to_list(self, accumulation):
return (accumulation["accum"],)
@VariantSupport()
class TestListToAccumulationNode:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"list": ("*",),
},
}
RETURN_TYPES = ("ACCUMULATION",)
INPUT_IS_LIST = (True,)
FUNCTION = "list_to_accumulation"
CATEGORY = "Testing/Lists"
def list_to_accumulation(self, list):
return ({"accum": list},)
@VariantSupport()
class TestAccumulationGetLengthNode:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"accumulation": ("ACCUMULATION",),
},
}
RETURN_TYPES = ("INT",)
FUNCTION = "accumlength"
CATEGORY = "Testing/Lists"
def accumlength(self, accumulation):
return (len(accumulation['accum']),)
@VariantSupport()
class TestAccumulationGetItemNode:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"accumulation": ("ACCUMULATION",),
"index": ("INT", {"default":0, "step":1})
},
}
RETURN_TYPES = ("*",)
FUNCTION = "get_item"
CATEGORY = "Testing/Lists"
def get_item(self, accumulation, index):
return (accumulation['accum'][index],)
@VariantSupport()
class TestAccumulationSetItemNode:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"accumulation": ("ACCUMULATION",),
"index": ("INT", {"default":0, "step":1}),
"value": ("*",),
},
}
RETURN_TYPES = ("ACCUMULATION",)
FUNCTION = "set_item"
CATEGORY = "Testing/Lists"
def set_item(self, accumulation, index, value):
new_accum = accumulation['accum'][:]
new_accum[index] = value
return ({"accum": new_accum},)
class TestIntMathOperation:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
"b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
"operation": (["add", "subtract", "multiply", "divide", "modulo", "power"],),
},
}
RETURN_TYPES = ("INT",)
FUNCTION = "int_math_operation"
CATEGORY = "Testing/Logic"
def int_math_operation(self, a, b, operation):
if operation == "add":
return (a + b,)
elif operation == "subtract":
return (a - b,)
elif operation == "multiply":
return (a * b,)
elif operation == "divide":
return (a // b,)
elif operation == "modulo":
return (a % b,)
elif operation == "power":
return (a ** b,)
from .flow_control import NUM_FLOW_SOCKETS
@VariantSupport()
class TestForLoopOpen:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"remaining": ("INT", {"default": 1, "min": 0, "max": 100000, "step": 1}),
},
"optional": {
f"initial_value{i}": ("*",) for i in range(1, NUM_FLOW_SOCKETS)
},
"hidden": {
"initial_value0": ("*",)
}
}
RETURN_TYPES = tuple(["FLOW_CONTROL", "INT",] + ["*"] * (NUM_FLOW_SOCKETS-1))
RETURN_NAMES = tuple(["flow_control", "remaining"] + [f"value{i}" for i in range(1, NUM_FLOW_SOCKETS)])
FUNCTION = "for_loop_open"
CATEGORY = "Testing/Flow"
def for_loop_open(self, remaining, **kwargs):
graph = GraphBuilder()
if "initial_value0" in kwargs:
remaining = kwargs["initial_value0"]
graph.node("TestWhileLoopOpen", condition=remaining, initial_value0=remaining, **{(f"initial_value{i}"): kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)})
outputs = [kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)]
return {
"result": tuple(["stub", remaining] + outputs),
"expand": graph.finalize(),
}
@VariantSupport()
class TestForLoopClose:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"flow_control": ("FLOW_CONTROL", {"rawLink": True}),
},
"optional": {
f"initial_value{i}": ("*",{"rawLink": True}) for i in range(1, NUM_FLOW_SOCKETS)
},
}
RETURN_TYPES = tuple(["*"] * (NUM_FLOW_SOCKETS-1))
RETURN_NAMES = tuple([f"value{i}" for i in range(1, NUM_FLOW_SOCKETS)])
FUNCTION = "for_loop_close"
CATEGORY = "Testing/Flow"
def for_loop_close(self, flow_control, **kwargs):
graph = GraphBuilder()
while_open = flow_control[0]
sub = graph.node("TestIntMathOperation", operation="subtract", a=[while_open,1], b=1)
cond = graph.node("TestToBoolNode", value=sub.out(0))
input_values = {f"initial_value{i}": kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)}
while_close = graph.node("TestWhileLoopClose",
flow_control=flow_control,
condition=cond.out(0),
initial_value0=sub.out(0),
**input_values)
return {
"result": tuple([while_close.out(i) for i in range(1, NUM_FLOW_SOCKETS)]),
"expand": graph.finalize(),
}
NUM_LIST_SOCKETS = 10
@VariantSupport()
class TestMakeListNode:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value1": ("*",),
},
"optional": {
f"value{i}": ("*",) for i in range(1, NUM_LIST_SOCKETS)
},
}
RETURN_TYPES = ("*",)
FUNCTION = "make_list"
OUTPUT_IS_LIST = (True,)
CATEGORY = "Testing/Lists"
def make_list(self, **kwargs):
result = []
for i in range(NUM_LIST_SOCKETS):
if f"value{i}" in kwargs:
result.append(kwargs[f"value{i}"])
return (result,)
UTILITY_NODE_CLASS_MAPPINGS = {
"TestAccumulateNode": TestAccumulateNode,
"TestAccumulationHeadNode": TestAccumulationHeadNode,
"TestAccumulationTailNode": TestAccumulationTailNode,
"TestAccumulationToListNode": TestAccumulationToListNode,
"TestListToAccumulationNode": TestListToAccumulationNode,
"TestAccumulationGetLengthNode": TestAccumulationGetLengthNode,
"TestAccumulationGetItemNode": TestAccumulationGetItemNode,
"TestAccumulationSetItemNode": TestAccumulationSetItemNode,
"TestForLoopOpen": TestForLoopOpen,
"TestForLoopClose": TestForLoopClose,
"TestIntMathOperation": TestIntMathOperation,
"TestMakeListNode": TestMakeListNode,
}
UTILITY_NODE_DISPLAY_NAME_MAPPINGS = {
"TestAccumulateNode": "Accumulate",
"TestAccumulationHeadNode": "Accumulation Head",
"TestAccumulationTailNode": "Accumulation Tail",
"TestAccumulationToListNode": "Accumulation to List",
"TestListToAccumulationNode": "List to Accumulation",
"TestAccumulationGetLengthNode": "Accumulation Get Length",
"TestAccumulationGetItemNode": "Accumulation Get Item",
"TestAccumulationSetItemNode": "Accumulation Set Item",
"TestForLoopOpen": "For Loop Open",
"TestForLoopClose": "For Loop Close",
"TestIntMathOperation": "Int Math Operation",
"TestMakeListNode": "Make List",
}