mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-13 21:16:09 +00:00
Fix showing progress from other sessions
Because `client_id` was missing from ths `progress_state` message, it was being sent to all connected sessions. This technically meant that if someone had a graph with the same nodes, they would see the progress updates for others. Also added a test to prevent reoccurance and moved the tests around to make CI easier to hook up.
This commit is contained in:
28
tests/execution/testing_nodes/testing-pack/__init__.py
Normal file
28
tests/execution/testing_nodes/testing-pack/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from .specific_tests import TEST_NODE_CLASS_MAPPINGS, TEST_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .async_test_nodes import ASYNC_TEST_NODE_CLASS_MAPPINGS, ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .api_test_nodes import API_TEST_NODE_CLASS_MAPPINGS, API_TEST_NODE_DISPLAY_NAME_MAPPINGS
|
||||
|
||||
# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
|
||||
# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {}
|
||||
NODE_CLASS_MAPPINGS.update(TEST_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(ASYNC_TEST_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(API_TEST_NODE_CLASS_MAPPINGS)
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(API_TEST_NODE_DISPLAY_NAME_MAPPINGS)
|
78
tests/execution/testing_nodes/testing-pack/api_test_nodes.py
Normal file
78
tests/execution/testing_nodes/testing-pack/api_test_nodes.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import asyncio
|
||||
import time
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
from comfy_api.v0_0_2 import ComfyAPI, ComfyAPISync
|
||||
|
||||
api = ComfyAPI()
|
||||
api_sync = ComfyAPISync()
|
||||
|
||||
|
||||
class TestAsyncProgressUpdate(ComfyNodeABC):
|
||||
"""Test node with async VALIDATE_INPUTS."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"sleep_seconds": (IO.FLOAT, {"default": 1.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def execute(self, value, sleep_seconds):
|
||||
start = time.time()
|
||||
expiration = start + sleep_seconds
|
||||
now = start
|
||||
while now < expiration:
|
||||
now = time.time()
|
||||
await api.execution.set_progress(
|
||||
value=(now - start) / sleep_seconds,
|
||||
max_value=1.0,
|
||||
)
|
||||
await asyncio.sleep(0.01)
|
||||
return (value,)
|
||||
|
||||
|
||||
class TestSyncProgressUpdate(ComfyNodeABC):
|
||||
"""Test node with async VALIDATE_INPUTS."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"sleep_seconds": (IO.FLOAT, {"default": 1.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
def execute(self, value, sleep_seconds):
|
||||
start = time.time()
|
||||
expiration = start + sleep_seconds
|
||||
now = start
|
||||
while now < expiration:
|
||||
now = time.time()
|
||||
api_sync.execution.set_progress(
|
||||
value=(now - start) / sleep_seconds,
|
||||
max_value=1.0,
|
||||
)
|
||||
time.sleep(0.01)
|
||||
return (value,)
|
||||
|
||||
|
||||
API_TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestAsyncProgressUpdate": TestAsyncProgressUpdate,
|
||||
"TestSyncProgressUpdate": TestSyncProgressUpdate,
|
||||
}
|
||||
|
||||
API_TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TestAsyncProgressUpdate": "Async Progress Update Test Node",
|
||||
"TestSyncProgressUpdate": "Sync Progress Update Test Node",
|
||||
}
|
343
tests/execution/testing_nodes/testing-pack/async_test_nodes.py
Normal file
343
tests/execution/testing_nodes/testing-pack/async_test_nodes.py
Normal file
@@ -0,0 +1,343 @@
|
||||
import torch
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
from comfy.utils import ProgressBar
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC
|
||||
from comfy.comfy_types import IO
|
||||
|
||||
|
||||
class TestAsyncValidation(ComfyNodeABC):
|
||||
"""Test node with async VALIDATE_INPUTS."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("FLOAT", {"default": 5.0}),
|
||||
"threshold": ("FLOAT", {"default": 10.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
@classmethod
|
||||
async def VALIDATE_INPUTS(cls, value, threshold):
|
||||
# Simulate async validation (e.g., checking remote service)
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
if value > threshold:
|
||||
return f"Value {value} exceeds threshold {threshold}"
|
||||
return True
|
||||
|
||||
def process(self, value, threshold):
|
||||
# Create image based on value
|
||||
intensity = value / 10.0
|
||||
image = torch.ones([1, 512, 512, 3]) * intensity
|
||||
return (image,)
|
||||
|
||||
|
||||
class TestAsyncError(ComfyNodeABC):
|
||||
"""Test node that errors during async execution."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"error_after": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 10.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "error_execution"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def error_execution(self, value, error_after):
|
||||
await asyncio.sleep(error_after)
|
||||
raise RuntimeError("Intentional async execution error for testing")
|
||||
|
||||
|
||||
class TestAsyncValidationError(ComfyNodeABC):
|
||||
"""Test node with async validation that always fails."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("FLOAT", {"default": 5.0}),
|
||||
"max_value": ("FLOAT", {"default": 10.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
@classmethod
|
||||
async def VALIDATE_INPUTS(cls, value, max_value):
|
||||
await asyncio.sleep(0.05)
|
||||
# Always fail validation for values > max_value
|
||||
if value > max_value:
|
||||
return f"Async validation failed: {value} > {max_value}"
|
||||
return True
|
||||
|
||||
def process(self, value, max_value):
|
||||
# This won't be reached if validation fails
|
||||
image = torch.ones([1, 512, 512, 3]) * (value / max_value)
|
||||
return (image,)
|
||||
|
||||
|
||||
class TestAsyncTimeout(ComfyNodeABC):
|
||||
"""Test node that simulates timeout scenarios."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"timeout": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0}),
|
||||
"operation_time": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 10.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "timeout_execution"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def timeout_execution(self, value, timeout, operation_time):
|
||||
try:
|
||||
# This will timeout if operation_time > timeout
|
||||
await asyncio.wait_for(asyncio.sleep(operation_time), timeout=timeout)
|
||||
return (value,)
|
||||
except asyncio.TimeoutError:
|
||||
raise RuntimeError(f"Operation timed out after {timeout} seconds")
|
||||
|
||||
|
||||
class TestSyncError(ComfyNodeABC):
|
||||
"""Test node that errors synchronously (for mixed sync/async testing)."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "sync_error"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
def sync_error(self, value):
|
||||
raise RuntimeError("Intentional sync execution error for testing")
|
||||
|
||||
|
||||
class TestAsyncLazyCheck(ComfyNodeABC):
|
||||
"""Test node with async check_lazy_status."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"input1": (IO.ANY, {"lazy": True}),
|
||||
"input2": (IO.ANY, {"lazy": True}),
|
||||
"condition": ("BOOLEAN", {"default": True}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def check_lazy_status(self, condition, input1, input2):
|
||||
# Simulate async checking (e.g., querying remote service)
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
needed = []
|
||||
if condition and input1 is None:
|
||||
needed.append("input1")
|
||||
if not condition and input2 is None:
|
||||
needed.append("input2")
|
||||
return needed
|
||||
|
||||
def process(self, input1, input2, condition):
|
||||
# Return a simple image
|
||||
return (torch.ones([1, 512, 512, 3]),)
|
||||
|
||||
|
||||
class TestDynamicAsyncGeneration(ComfyNodeABC):
|
||||
"""Test node that dynamically generates async nodes."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image1": ("IMAGE",),
|
||||
"image2": ("IMAGE",),
|
||||
"num_async_nodes": ("INT", {"default": 3, "min": 1, "max": 10}),
|
||||
"sleep_duration": ("FLOAT", {"default": 0.2, "min": 0.1, "max": 1.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "generate_async_workflow"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration):
|
||||
g = GraphBuilder()
|
||||
|
||||
# Create multiple async sleep nodes
|
||||
sleep_nodes = []
|
||||
for i in range(num_async_nodes):
|
||||
image = image1 if i % 2 == 0 else image2
|
||||
sleep_node = g.node("TestSleep", value=image, seconds=sleep_duration)
|
||||
sleep_nodes.append(sleep_node)
|
||||
|
||||
# Average all results
|
||||
if len(sleep_nodes) == 1:
|
||||
final_node = sleep_nodes[0]
|
||||
else:
|
||||
avg_inputs = {"input1": sleep_nodes[0].out(0)}
|
||||
for i, node in enumerate(sleep_nodes[1:], 2):
|
||||
avg_inputs[f"input{i}"] = node.out(0)
|
||||
final_node = g.node("TestVariadicAverage", **avg_inputs)
|
||||
|
||||
return {
|
||||
"result": (final_node.out(0),),
|
||||
"expand": g.finalize(),
|
||||
}
|
||||
|
||||
|
||||
class TestAsyncResourceUser(ComfyNodeABC):
|
||||
"""Test node that uses resources during async execution."""
|
||||
|
||||
# Class-level resource tracking for testing
|
||||
_active_resources: Dict[str, bool] = {}
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"resource_id": ("STRING", {"default": "resource_0"}),
|
||||
"duration": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "use_resource"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def use_resource(self, value, resource_id, duration):
|
||||
# Check if resource is already in use
|
||||
if self._active_resources.get(resource_id, False):
|
||||
raise RuntimeError(f"Resource {resource_id} is already in use!")
|
||||
|
||||
# Mark resource as in use
|
||||
self._active_resources[resource_id] = True
|
||||
|
||||
try:
|
||||
# Simulate resource usage
|
||||
await asyncio.sleep(duration)
|
||||
return (value,)
|
||||
finally:
|
||||
# Always clean up resource
|
||||
self._active_resources[resource_id] = False
|
||||
|
||||
|
||||
class TestAsyncBatchProcessing(ComfyNodeABC):
|
||||
"""Test async processing of batched inputs."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"images": ("IMAGE",),
|
||||
"process_time_per_item": ("FLOAT", {"default": 0.1, "min": 0.01, "max": 1.0}),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process_batch"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def process_batch(self, images, process_time_per_item, unique_id):
|
||||
batch_size = images.shape[0]
|
||||
pbar = ProgressBar(batch_size, node_id=unique_id)
|
||||
|
||||
# Process each image in the batch
|
||||
processed = []
|
||||
for i in range(batch_size):
|
||||
# Simulate async processing
|
||||
await asyncio.sleep(process_time_per_item)
|
||||
|
||||
# Simple processing: invert the image
|
||||
processed_image = 1.0 - images[i:i+1]
|
||||
processed.append(processed_image)
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
# Stack processed images
|
||||
result = torch.cat(processed, dim=0)
|
||||
return (result,)
|
||||
|
||||
|
||||
class TestAsyncConcurrentLimit(ComfyNodeABC):
|
||||
"""Test concurrent execution limits for async nodes."""
|
||||
|
||||
_semaphore = asyncio.Semaphore(2) # Only allow 2 concurrent executions
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"duration": ("FLOAT", {"default": 0.5, "min": 0.1, "max": 2.0}),
|
||||
"node_id": ("INT", {"default": 0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "limited_execution"
|
||||
CATEGORY = "_for_testing/async"
|
||||
|
||||
async def limited_execution(self, value, duration, node_id):
|
||||
async with self._semaphore:
|
||||
# Node {node_id} acquired semaphore
|
||||
await asyncio.sleep(duration)
|
||||
# Node {node_id} releasing semaphore
|
||||
return (value,)
|
||||
|
||||
|
||||
# Add node mappings
|
||||
ASYNC_TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestAsyncValidation": TestAsyncValidation,
|
||||
"TestAsyncError": TestAsyncError,
|
||||
"TestAsyncValidationError": TestAsyncValidationError,
|
||||
"TestAsyncTimeout": TestAsyncTimeout,
|
||||
"TestSyncError": TestSyncError,
|
||||
"TestAsyncLazyCheck": TestAsyncLazyCheck,
|
||||
"TestDynamicAsyncGeneration": TestDynamicAsyncGeneration,
|
||||
"TestAsyncResourceUser": TestAsyncResourceUser,
|
||||
"TestAsyncBatchProcessing": TestAsyncBatchProcessing,
|
||||
"TestAsyncConcurrentLimit": TestAsyncConcurrentLimit,
|
||||
}
|
||||
|
||||
ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TestAsyncValidation": "Test Async Validation",
|
||||
"TestAsyncError": "Test Async Error",
|
||||
"TestAsyncValidationError": "Test Async Validation Error",
|
||||
"TestAsyncTimeout": "Test Async Timeout",
|
||||
"TestSyncError": "Test Sync Error",
|
||||
"TestAsyncLazyCheck": "Test Async Lazy Check",
|
||||
"TestDynamicAsyncGeneration": "Test Dynamic Async Generation",
|
||||
"TestAsyncResourceUser": "Test Async Resource User",
|
||||
"TestAsyncBatchProcessing": "Test Async Batch Processing",
|
||||
"TestAsyncConcurrentLimit": "Test Async Concurrent Limit",
|
||||
}
|
194
tests/execution/testing_nodes/testing-pack/conditions.py
Normal file
194
tests/execution/testing_nodes/testing-pack/conditions.py
Normal file
@@ -0,0 +1,194 @@
|
||||
import re
|
||||
import torch
|
||||
|
||||
class TestIntConditions:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
|
||||
"b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
|
||||
"operation": (["==", "!=", "<", ">", "<=", ">="],),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("BOOLEAN",)
|
||||
FUNCTION = "int_condition"
|
||||
|
||||
CATEGORY = "Testing/Logic"
|
||||
|
||||
def int_condition(self, a, b, operation):
|
||||
if operation == "==":
|
||||
return (a == b,)
|
||||
elif operation == "!=":
|
||||
return (a != b,)
|
||||
elif operation == "<":
|
||||
return (a < b,)
|
||||
elif operation == ">":
|
||||
return (a > b,)
|
||||
elif operation == "<=":
|
||||
return (a <= b,)
|
||||
elif operation == ">=":
|
||||
return (a >= b,)
|
||||
|
||||
|
||||
class TestFloatConditions:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"a": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 999999999999.0, "step": 1}),
|
||||
"b": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 999999999999.0, "step": 1}),
|
||||
"operation": (["==", "!=", "<", ">", "<=", ">="],),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("BOOLEAN",)
|
||||
FUNCTION = "float_condition"
|
||||
|
||||
CATEGORY = "Testing/Logic"
|
||||
|
||||
def float_condition(self, a, b, operation):
|
||||
if operation == "==":
|
||||
return (a == b,)
|
||||
elif operation == "!=":
|
||||
return (a != b,)
|
||||
elif operation == "<":
|
||||
return (a < b,)
|
||||
elif operation == ">":
|
||||
return (a > b,)
|
||||
elif operation == "<=":
|
||||
return (a <= b,)
|
||||
elif operation == ">=":
|
||||
return (a >= b,)
|
||||
|
||||
class TestStringConditions:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"a": ("STRING", {"multiline": False}),
|
||||
"b": ("STRING", {"multiline": False}),
|
||||
"operation": (["a == b", "a != b", "a IN b", "a MATCH REGEX(b)", "a BEGINSWITH b", "a ENDSWITH b"],),
|
||||
"case_sensitive": ("BOOLEAN", {"default": True}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("BOOLEAN",)
|
||||
FUNCTION = "string_condition"
|
||||
|
||||
CATEGORY = "Testing/Logic"
|
||||
|
||||
def string_condition(self, a, b, operation, case_sensitive):
|
||||
if not case_sensitive:
|
||||
a = a.lower()
|
||||
b = b.lower()
|
||||
|
||||
if operation == "a == b":
|
||||
return (a == b,)
|
||||
elif operation == "a != b":
|
||||
return (a != b,)
|
||||
elif operation == "a IN b":
|
||||
return (a in b,)
|
||||
elif operation == "a MATCH REGEX(b)":
|
||||
try:
|
||||
return (re.match(b, a) is not None,)
|
||||
except:
|
||||
return (False,)
|
||||
elif operation == "a BEGINSWITH b":
|
||||
return (a.startswith(b),)
|
||||
elif operation == "a ENDSWITH b":
|
||||
return (a.endswith(b),)
|
||||
|
||||
class TestToBoolNode:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("*",),
|
||||
},
|
||||
"optional": {
|
||||
"invert": ("BOOLEAN", {"default": False}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("BOOLEAN",)
|
||||
FUNCTION = "to_bool"
|
||||
|
||||
CATEGORY = "Testing/Logic"
|
||||
|
||||
def to_bool(self, value, invert = False):
|
||||
if isinstance(value, torch.Tensor):
|
||||
if value.max().item() == 0 and value.min().item() == 0:
|
||||
result = False
|
||||
else:
|
||||
result = True
|
||||
else:
|
||||
try:
|
||||
result = bool(value)
|
||||
except:
|
||||
# Can't convert it? Well then it's something or other. I dunno, I'm not a Python programmer.
|
||||
result = True
|
||||
|
||||
if invert:
|
||||
result = not result
|
||||
|
||||
return (result,)
|
||||
|
||||
class TestBoolOperationNode:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"a": ("BOOLEAN",),
|
||||
"b": ("BOOLEAN",),
|
||||
"op": (["a AND b", "a OR b", "a XOR b", "NOT a"],),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("BOOLEAN",)
|
||||
FUNCTION = "bool_operation"
|
||||
|
||||
CATEGORY = "Testing/Logic"
|
||||
|
||||
def bool_operation(self, a, b, op):
|
||||
if op == "a AND b":
|
||||
return (a and b,)
|
||||
elif op == "a OR b":
|
||||
return (a or b,)
|
||||
elif op == "a XOR b":
|
||||
return (a ^ b,)
|
||||
elif op == "NOT a":
|
||||
return (not a,)
|
||||
|
||||
|
||||
CONDITION_NODE_CLASS_MAPPINGS = {
|
||||
"TestIntConditions": TestIntConditions,
|
||||
"TestFloatConditions": TestFloatConditions,
|
||||
"TestStringConditions": TestStringConditions,
|
||||
"TestToBoolNode": TestToBoolNode,
|
||||
"TestBoolOperationNode": TestBoolOperationNode,
|
||||
}
|
||||
|
||||
CONDITION_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TestIntConditions": "Int Condition",
|
||||
"TestFloatConditions": "Float Condition",
|
||||
"TestStringConditions": "String Condition",
|
||||
"TestToBoolNode": "To Bool",
|
||||
"TestBoolOperationNode": "Bool Operation",
|
||||
}
|
173
tests/execution/testing_nodes/testing-pack/flow_control.py
Normal file
173
tests/execution/testing_nodes/testing-pack/flow_control.py
Normal file
@@ -0,0 +1,173 @@
|
||||
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||
from comfy_execution.graph import ExecutionBlocker
|
||||
from .tools import VariantSupport
|
||||
|
||||
NUM_FLOW_SOCKETS = 5
|
||||
@VariantSupport()
|
||||
class TestWhileLoopOpen:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
inputs = {
|
||||
"required": {
|
||||
"condition": ("BOOLEAN", {"default": True}),
|
||||
},
|
||||
"optional": {
|
||||
},
|
||||
}
|
||||
for i in range(NUM_FLOW_SOCKETS):
|
||||
inputs["optional"][f"initial_value{i}"] = ("*",)
|
||||
return inputs
|
||||
|
||||
RETURN_TYPES = tuple(["FLOW_CONTROL"] + ["*"] * NUM_FLOW_SOCKETS)
|
||||
RETURN_NAMES = tuple(["FLOW_CONTROL"] + [f"value{i}" for i in range(NUM_FLOW_SOCKETS)])
|
||||
FUNCTION = "while_loop_open"
|
||||
|
||||
CATEGORY = "Testing/Flow"
|
||||
|
||||
def while_loop_open(self, condition, **kwargs):
|
||||
values = []
|
||||
for i in range(NUM_FLOW_SOCKETS):
|
||||
values.append(kwargs.get(f"initial_value{i}", None))
|
||||
return tuple(["stub"] + values)
|
||||
|
||||
@VariantSupport()
|
||||
class TestWhileLoopClose:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
inputs = {
|
||||
"required": {
|
||||
"flow_control": ("FLOW_CONTROL", {"rawLink": True}),
|
||||
"condition": ("BOOLEAN", {"forceInput": True}),
|
||||
},
|
||||
"optional": {
|
||||
},
|
||||
"hidden": {
|
||||
"dynprompt": "DYNPROMPT",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
}
|
||||
}
|
||||
for i in range(NUM_FLOW_SOCKETS):
|
||||
inputs["optional"][f"initial_value{i}"] = ("*",)
|
||||
return inputs
|
||||
|
||||
RETURN_TYPES = tuple(["*"] * NUM_FLOW_SOCKETS)
|
||||
RETURN_NAMES = tuple([f"value{i}" for i in range(NUM_FLOW_SOCKETS)])
|
||||
FUNCTION = "while_loop_close"
|
||||
|
||||
CATEGORY = "Testing/Flow"
|
||||
|
||||
def explore_dependencies(self, node_id, dynprompt, upstream):
|
||||
node_info = dynprompt.get_node(node_id)
|
||||
if "inputs" not in node_info:
|
||||
return
|
||||
for k, v in node_info["inputs"].items():
|
||||
if is_link(v):
|
||||
parent_id = v[0]
|
||||
if parent_id not in upstream:
|
||||
upstream[parent_id] = []
|
||||
self.explore_dependencies(parent_id, dynprompt, upstream)
|
||||
upstream[parent_id].append(node_id)
|
||||
|
||||
def collect_contained(self, node_id, upstream, contained):
|
||||
if node_id not in upstream:
|
||||
return
|
||||
for child_id in upstream[node_id]:
|
||||
if child_id not in contained:
|
||||
contained[child_id] = True
|
||||
self.collect_contained(child_id, upstream, contained)
|
||||
|
||||
|
||||
def while_loop_close(self, flow_control, condition, dynprompt=None, unique_id=None, **kwargs):
|
||||
assert dynprompt is not None
|
||||
if not condition:
|
||||
# We're done with the loop
|
||||
values = []
|
||||
for i in range(NUM_FLOW_SOCKETS):
|
||||
values.append(kwargs.get(f"initial_value{i}", None))
|
||||
return tuple(values)
|
||||
|
||||
# We want to loop
|
||||
upstream = {}
|
||||
# Get the list of all nodes between the open and close nodes
|
||||
self.explore_dependencies(unique_id, dynprompt, upstream)
|
||||
|
||||
contained = {}
|
||||
open_node = flow_control[0]
|
||||
self.collect_contained(open_node, upstream, contained)
|
||||
contained[unique_id] = True
|
||||
contained[open_node] = True
|
||||
|
||||
# We'll use the default prefix, but to avoid having node names grow exponentially in size,
|
||||
# we'll use "Recurse" for the name of the recursively-generated copy of this node.
|
||||
graph = GraphBuilder()
|
||||
for node_id in contained:
|
||||
original_node = dynprompt.get_node(node_id)
|
||||
node = graph.node(original_node["class_type"], "Recurse" if node_id == unique_id else node_id)
|
||||
node.set_override_display_id(node_id)
|
||||
for node_id in contained:
|
||||
original_node = dynprompt.get_node(node_id)
|
||||
node = graph.lookup_node("Recurse" if node_id == unique_id else node_id)
|
||||
assert node is not None
|
||||
for k, v in original_node["inputs"].items():
|
||||
if is_link(v) and v[0] in contained:
|
||||
parent = graph.lookup_node(v[0])
|
||||
assert parent is not None
|
||||
node.set_input(k, parent.out(v[1]))
|
||||
else:
|
||||
node.set_input(k, v)
|
||||
new_open = graph.lookup_node(open_node)
|
||||
assert new_open is not None
|
||||
for i in range(NUM_FLOW_SOCKETS):
|
||||
key = f"initial_value{i}"
|
||||
new_open.set_input(key, kwargs.get(key, None))
|
||||
my_clone = graph.lookup_node("Recurse")
|
||||
assert my_clone is not None
|
||||
result = map(lambda x: my_clone.out(x), range(NUM_FLOW_SOCKETS))
|
||||
return {
|
||||
"result": tuple(result),
|
||||
"expand": graph.finalize(),
|
||||
}
|
||||
|
||||
@VariantSupport()
|
||||
class TestExecutionBlockerNode:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
inputs = {
|
||||
"required": {
|
||||
"input": ("*",),
|
||||
"block": ("BOOLEAN",),
|
||||
"verbose": ("BOOLEAN", {"default": False}),
|
||||
},
|
||||
}
|
||||
return inputs
|
||||
|
||||
RETURN_TYPES = ("*",)
|
||||
RETURN_NAMES = ("output",)
|
||||
FUNCTION = "execution_blocker"
|
||||
|
||||
CATEGORY = "Testing/Flow"
|
||||
|
||||
def execution_blocker(self, input, block, verbose):
|
||||
if block:
|
||||
return (ExecutionBlocker("Blocked Execution" if verbose else None),)
|
||||
return (input,)
|
||||
|
||||
FLOW_CONTROL_NODE_CLASS_MAPPINGS = {
|
||||
"TestWhileLoopOpen": TestWhileLoopOpen,
|
||||
"TestWhileLoopClose": TestWhileLoopClose,
|
||||
"TestExecutionBlocker": TestExecutionBlockerNode,
|
||||
}
|
||||
FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TestWhileLoopOpen": "While Loop Open",
|
||||
"TestWhileLoopClose": "While Loop Close",
|
||||
"TestExecutionBlocker": "Execution Blocker",
|
||||
}
|
519
tests/execution/testing_nodes/testing-pack/specific_tests.py
Normal file
519
tests/execution/testing_nodes/testing-pack/specific_tests.py
Normal file
@@ -0,0 +1,519 @@
|
||||
import torch
|
||||
import time
|
||||
import asyncio
|
||||
from comfy.utils import ProgressBar
|
||||
from .tools import VariantSupport
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC
|
||||
from comfy.comfy_types import IO
|
||||
|
||||
class TestLazyMixImages:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image1": ("IMAGE",{"lazy": True}),
|
||||
"image2": ("IMAGE",{"lazy": True}),
|
||||
"mask": ("MASK",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "mix"
|
||||
|
||||
CATEGORY = "Testing/Nodes"
|
||||
|
||||
def check_lazy_status(self, mask, image1, image2):
|
||||
mask_min = mask.min()
|
||||
mask_max = mask.max()
|
||||
needed = []
|
||||
if image1 is None and (mask_min != 1.0 or mask_max != 1.0):
|
||||
needed.append("image1")
|
||||
if image2 is None and (mask_min != 0.0 or mask_max != 0.0):
|
||||
needed.append("image2")
|
||||
return needed
|
||||
|
||||
# Not trying to handle different batch sizes here just to keep the demo simple
|
||||
def mix(self, mask, image1, image2):
|
||||
mask_min = mask.min()
|
||||
mask_max = mask.max()
|
||||
if mask_min == 0.0 and mask_max == 0.0:
|
||||
return (image1,)
|
||||
elif mask_min == 1.0 and mask_max == 1.0:
|
||||
return (image2,)
|
||||
|
||||
if len(mask.shape) == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
if len(mask.shape) == 3:
|
||||
mask = mask.unsqueeze(3)
|
||||
if mask.shape[3] < image1.shape[3]:
|
||||
mask = mask.repeat(1, 1, 1, image1.shape[3])
|
||||
|
||||
result = image1 * (1. - mask) + image2 * mask,
|
||||
return (result[0],)
|
||||
|
||||
class TestVariadicAverage:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"input1": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "variadic_average"
|
||||
|
||||
CATEGORY = "Testing/Nodes"
|
||||
|
||||
def variadic_average(self, input1, **kwargs):
|
||||
inputs = [input1]
|
||||
while 'input' + str(len(inputs) + 1) in kwargs:
|
||||
inputs.append(kwargs['input' + str(len(inputs) + 1)])
|
||||
return (torch.stack(inputs).mean(dim=0),)
|
||||
|
||||
|
||||
class TestCustomIsChanged:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
},
|
||||
"optional": {
|
||||
"should_change": ("BOOL", {"default": False}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "custom_is_changed"
|
||||
|
||||
CATEGORY = "Testing/Nodes"
|
||||
|
||||
def custom_is_changed(self, image, should_change=False):
|
||||
return (image,)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(cls, should_change=False, *args, **kwargs):
|
||||
if should_change:
|
||||
return float("NaN")
|
||||
else:
|
||||
return False
|
||||
|
||||
class TestIsChangedWithConstants:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "custom_is_changed"
|
||||
|
||||
CATEGORY = "Testing/Nodes"
|
||||
|
||||
def custom_is_changed(self, image, value):
|
||||
return (image * value,)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(cls, image, value):
|
||||
if image is None:
|
||||
return value
|
||||
else:
|
||||
return image.mean().item() * value
|
||||
|
||||
class TestCustomValidation1:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"input1": ("IMAGE,FLOAT",),
|
||||
"input2": ("IMAGE,FLOAT",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "custom_validation1"
|
||||
|
||||
CATEGORY = "Testing/Nodes"
|
||||
|
||||
def custom_validation1(self, input1, input2):
|
||||
if isinstance(input1, float) and isinstance(input2, float):
|
||||
result = torch.ones([1, 512, 512, 3]) * input1 * input2
|
||||
else:
|
||||
result = input1 * input2
|
||||
return (result,)
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, input1=None, input2=None):
|
||||
if input1 is not None:
|
||||
if not isinstance(input1, (torch.Tensor, float)):
|
||||
return f"Invalid type of input1: {type(input1)}"
|
||||
if input2 is not None:
|
||||
if not isinstance(input2, (torch.Tensor, float)):
|
||||
return f"Invalid type of input2: {type(input2)}"
|
||||
|
||||
return True
|
||||
|
||||
class TestCustomValidation2:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"input1": ("IMAGE,FLOAT",),
|
||||
"input2": ("IMAGE,FLOAT",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "custom_validation2"
|
||||
|
||||
CATEGORY = "Testing/Nodes"
|
||||
|
||||
def custom_validation2(self, input1, input2):
|
||||
if isinstance(input1, float) and isinstance(input2, float):
|
||||
result = torch.ones([1, 512, 512, 3]) * input1 * input2
|
||||
else:
|
||||
result = input1 * input2
|
||||
return (result,)
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, input_types, input1=None, input2=None):
|
||||
if input1 is not None:
|
||||
if not isinstance(input1, (torch.Tensor, float)):
|
||||
return f"Invalid type of input1: {type(input1)}"
|
||||
if input2 is not None:
|
||||
if not isinstance(input2, (torch.Tensor, float)):
|
||||
return f"Invalid type of input2: {type(input2)}"
|
||||
|
||||
if 'input1' in input_types:
|
||||
if input_types['input1'] not in ["IMAGE", "FLOAT"]:
|
||||
return f"Invalid type of input1: {input_types['input1']}"
|
||||
if 'input2' in input_types:
|
||||
if input_types['input2'] not in ["IMAGE", "FLOAT"]:
|
||||
return f"Invalid type of input2: {input_types['input2']}"
|
||||
|
||||
return True
|
||||
|
||||
@VariantSupport()
|
||||
class TestCustomValidation3:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"input1": ("IMAGE,FLOAT",),
|
||||
"input2": ("IMAGE,FLOAT",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "custom_validation3"
|
||||
|
||||
CATEGORY = "Testing/Nodes"
|
||||
|
||||
def custom_validation3(self, input1, input2):
|
||||
if isinstance(input1, float) and isinstance(input2, float):
|
||||
result = torch.ones([1, 512, 512, 3]) * input1 * input2
|
||||
else:
|
||||
result = input1 * input2
|
||||
return (result,)
|
||||
|
||||
class TestCustomValidation4:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"input1": ("FLOAT",),
|
||||
"input2": ("FLOAT",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "custom_validation4"
|
||||
|
||||
CATEGORY = "Testing/Nodes"
|
||||
|
||||
def custom_validation4(self, input1, input2):
|
||||
result = torch.ones([1, 512, 512, 3]) * input1 * input2
|
||||
return (result,)
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, input1, input2):
|
||||
if input1 is not None:
|
||||
if not isinstance(input1, float):
|
||||
return f"Invalid type of input1: {type(input1)}"
|
||||
if input2 is not None:
|
||||
if not isinstance(input2, float):
|
||||
return f"Invalid type of input2: {type(input2)}"
|
||||
|
||||
return True
|
||||
|
||||
class TestCustomValidation5:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"input1": ("FLOAT", {"min": 0.0, "max": 1.0}),
|
||||
"input2": ("FLOAT", {"min": 0.0, "max": 1.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "custom_validation5"
|
||||
|
||||
CATEGORY = "Testing/Nodes"
|
||||
|
||||
def custom_validation5(self, input1, input2):
|
||||
value = input1 * input2
|
||||
return (torch.ones([1, 512, 512, 3]) * value,)
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, **kwargs):
|
||||
if kwargs['input2'] == 7.0:
|
||||
return "7s are not allowed. I've never liked 7s."
|
||||
return True
|
||||
|
||||
class TestDynamicDependencyCycle:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"input1": ("IMAGE",),
|
||||
"input2": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "dynamic_dependency_cycle"
|
||||
|
||||
CATEGORY = "Testing/Nodes"
|
||||
|
||||
def dynamic_dependency_cycle(self, input1, input2):
|
||||
g = GraphBuilder()
|
||||
mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
|
||||
mix1 = g.node("TestLazyMixImages", image1=input1, mask=mask.out(0))
|
||||
mix2 = g.node("TestLazyMixImages", image1=mix1.out(0), image2=input2, mask=mask.out(0))
|
||||
|
||||
# Create the cyle
|
||||
mix1.set_input("image2", mix2.out(0))
|
||||
|
||||
return {
|
||||
"result": (mix2.out(0),),
|
||||
"expand": g.finalize(),
|
||||
}
|
||||
|
||||
class TestMixedExpansionReturns:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"input1": ("FLOAT",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE","IMAGE")
|
||||
FUNCTION = "mixed_expansion_returns"
|
||||
|
||||
CATEGORY = "Testing/Nodes"
|
||||
|
||||
def mixed_expansion_returns(self, input1):
|
||||
white_image = torch.ones([1, 512, 512, 3])
|
||||
if input1 <= 0.1:
|
||||
return (torch.ones([1, 512, 512, 3]) * 0.1, white_image)
|
||||
elif input1 <= 0.2:
|
||||
return {
|
||||
"result": (torch.ones([1, 512, 512, 3]) * 0.2, white_image),
|
||||
}
|
||||
else:
|
||||
g = GraphBuilder()
|
||||
mask = g.node("StubMask", value=0.3, height=512, width=512, batch_size=1)
|
||||
black = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
white = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
mix = g.node("TestLazyMixImages", image1=black.out(0), image2=white.out(0), mask=mask.out(0))
|
||||
return {
|
||||
"result": (mix.out(0), white_image),
|
||||
"expand": g.finalize(),
|
||||
}
|
||||
|
||||
class TestSamplingInExpansion:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"clip": ("CLIP",),
|
||||
"vae": ("VAE",),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||
"steps": ("INT", {"default": 20, "min": 1, "max": 100}),
|
||||
"cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 30.0}),
|
||||
"prompt": ("STRING", {"multiline": True, "default": "a beautiful landscape with mountains and trees"}),
|
||||
"negative_prompt": ("STRING", {"multiline": True, "default": "blurry, bad quality, worst quality"}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "sampling_in_expansion"
|
||||
|
||||
CATEGORY = "Testing/Nodes"
|
||||
|
||||
def sampling_in_expansion(self, model, clip, vae, seed, steps, cfg, prompt, negative_prompt):
|
||||
g = GraphBuilder()
|
||||
|
||||
# Create a basic image generation workflow using the input model, clip and vae
|
||||
# 1. Setup text prompts using the provided CLIP model
|
||||
positive_prompt = g.node("CLIPTextEncode",
|
||||
text=prompt,
|
||||
clip=clip)
|
||||
negative_prompt = g.node("CLIPTextEncode",
|
||||
text=negative_prompt,
|
||||
clip=clip)
|
||||
|
||||
# 2. Create empty latent with specified size
|
||||
empty_latent = g.node("EmptyLatentImage", width=512, height=512, batch_size=1)
|
||||
|
||||
# 3. Setup sampler and generate image latent
|
||||
sampler = g.node("KSampler",
|
||||
model=model,
|
||||
positive=positive_prompt.out(0),
|
||||
negative=negative_prompt.out(0),
|
||||
latent_image=empty_latent.out(0),
|
||||
seed=seed,
|
||||
steps=steps,
|
||||
cfg=cfg,
|
||||
sampler_name="euler_ancestral",
|
||||
scheduler="normal")
|
||||
|
||||
# 4. Decode latent to image using VAE
|
||||
output = g.node("VAEDecode", samples=sampler.out(0), vae=vae)
|
||||
|
||||
return {
|
||||
"result": (output.out(0),),
|
||||
"expand": g.finalize(),
|
||||
}
|
||||
|
||||
class TestSleep(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": (IO.ANY, {}),
|
||||
"seconds": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 9999.0, "step": 0.01, "tooltip": "The amount of seconds to sleep."}),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
RETURN_TYPES = (IO.ANY,)
|
||||
FUNCTION = "sleep"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
async def sleep(self, value, seconds, unique_id):
|
||||
pbar = ProgressBar(seconds, node_id=unique_id)
|
||||
start = time.time()
|
||||
expiration = start + seconds
|
||||
now = start
|
||||
while now < expiration:
|
||||
now = time.time()
|
||||
pbar.update_absolute(now - start)
|
||||
await asyncio.sleep(0.01)
|
||||
return (value,)
|
||||
|
||||
class TestParallelSleep(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image1": ("IMAGE", ),
|
||||
"image2": ("IMAGE", ),
|
||||
"image3": ("IMAGE", ),
|
||||
"sleep1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"sleep2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"sleep3": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "parallel_sleep"
|
||||
CATEGORY = "_for_testing"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id):
|
||||
# Create a graph dynamically with three TestSleep nodes
|
||||
g = GraphBuilder()
|
||||
|
||||
# Create sleep nodes for each duration and image
|
||||
sleep_node1 = g.node("TestSleep", value=image1, seconds=sleep1)
|
||||
sleep_node2 = g.node("TestSleep", value=image2, seconds=sleep2)
|
||||
sleep_node3 = g.node("TestSleep", value=image3, seconds=sleep3)
|
||||
|
||||
# Blend the results using TestVariadicAverage
|
||||
blend = g.node("TestVariadicAverage",
|
||||
input1=sleep_node1.out(0),
|
||||
input2=sleep_node2.out(0),
|
||||
input3=sleep_node3.out(0))
|
||||
|
||||
return {
|
||||
"result": (blend.out(0),),
|
||||
"expand": g.finalize(),
|
||||
}
|
||||
|
||||
class TestOutputNodeWithSocketOutput:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
|
||||
},
|
||||
}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "process"
|
||||
CATEGORY = "_for_testing"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def process(self, image, value):
|
||||
# Apply value scaling and return both as output and socket
|
||||
result = image * value
|
||||
return (result,)
|
||||
|
||||
TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestLazyMixImages": TestLazyMixImages,
|
||||
"TestVariadicAverage": TestVariadicAverage,
|
||||
"TestCustomIsChanged": TestCustomIsChanged,
|
||||
"TestIsChangedWithConstants": TestIsChangedWithConstants,
|
||||
"TestCustomValidation1": TestCustomValidation1,
|
||||
"TestCustomValidation2": TestCustomValidation2,
|
||||
"TestCustomValidation3": TestCustomValidation3,
|
||||
"TestCustomValidation4": TestCustomValidation4,
|
||||
"TestCustomValidation5": TestCustomValidation5,
|
||||
"TestDynamicDependencyCycle": TestDynamicDependencyCycle,
|
||||
"TestMixedExpansionReturns": TestMixedExpansionReturns,
|
||||
"TestSamplingInExpansion": TestSamplingInExpansion,
|
||||
"TestSleep": TestSleep,
|
||||
"TestParallelSleep": TestParallelSleep,
|
||||
"TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
|
||||
}
|
||||
|
||||
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TestLazyMixImages": "Lazy Mix Images",
|
||||
"TestVariadicAverage": "Variadic Average",
|
||||
"TestCustomIsChanged": "Custom IsChanged",
|
||||
"TestIsChangedWithConstants": "IsChanged With Constants",
|
||||
"TestCustomValidation1": "Custom Validation 1",
|
||||
"TestCustomValidation2": "Custom Validation 2",
|
||||
"TestCustomValidation3": "Custom Validation 3",
|
||||
"TestCustomValidation4": "Custom Validation 4",
|
||||
"TestCustomValidation5": "Custom Validation 5",
|
||||
"TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
|
||||
"TestMixedExpansionReturns": "Mixed Expansion Returns",
|
||||
"TestSamplingInExpansion": "Sampling In Expansion",
|
||||
"TestSleep": "Test Sleep",
|
||||
"TestParallelSleep": "Test Parallel Sleep",
|
||||
"TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
|
||||
}
|
129
tests/execution/testing_nodes/testing-pack/stubs.py
Normal file
129
tests/execution/testing_nodes/testing-pack/stubs.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import torch
|
||||
|
||||
class StubImage:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"content": (['WHITE', 'BLACK', 'NOISE'],),
|
||||
"height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
|
||||
"width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "stub_image"
|
||||
|
||||
CATEGORY = "Testing/Stub Nodes"
|
||||
|
||||
def stub_image(self, content, height, width, batch_size):
|
||||
if content == "WHITE":
|
||||
return (torch.ones(batch_size, height, width, 3),)
|
||||
elif content == "BLACK":
|
||||
return (torch.zeros(batch_size, height, width, 3),)
|
||||
elif content == "NOISE":
|
||||
return (torch.rand(batch_size, height, width, 3),)
|
||||
|
||||
class StubConstantImage:
|
||||
def __init__(self):
|
||||
pass
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
"height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
|
||||
"width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "stub_constant_image"
|
||||
|
||||
CATEGORY = "Testing/Stub Nodes"
|
||||
|
||||
def stub_constant_image(self, value, height, width, batch_size):
|
||||
return (torch.ones(batch_size, height, width, 3) * value,)
|
||||
|
||||
class StubMask:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
"height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
|
||||
"width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "stub_mask"
|
||||
|
||||
CATEGORY = "Testing/Stub Nodes"
|
||||
|
||||
def stub_mask(self, value, height, width, batch_size):
|
||||
return (torch.ones(batch_size, height, width) * value,)
|
||||
|
||||
class StubInt:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("INT", {"default": 0, "min": -0xffffffff, "max": 0xffffffff, "step": 1}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("INT",)
|
||||
FUNCTION = "stub_int"
|
||||
|
||||
CATEGORY = "Testing/Stub Nodes"
|
||||
|
||||
def stub_int(self, value):
|
||||
return (value,)
|
||||
|
||||
class StubFloat:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("FLOAT", {"default": 0.0, "min": -1.0e38, "max": 1.0e38, "step": 0.01}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
FUNCTION = "stub_float"
|
||||
|
||||
CATEGORY = "Testing/Stub Nodes"
|
||||
|
||||
def stub_float(self, value):
|
||||
return (value,)
|
||||
|
||||
TEST_STUB_NODE_CLASS_MAPPINGS = {
|
||||
"StubImage": StubImage,
|
||||
"StubConstantImage": StubConstantImage,
|
||||
"StubMask": StubMask,
|
||||
"StubInt": StubInt,
|
||||
"StubFloat": StubFloat,
|
||||
}
|
||||
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"StubImage": "Stub Image",
|
||||
"StubConstantImage": "Stub Constant Image",
|
||||
"StubMask": "Stub Mask",
|
||||
"StubInt": "Stub Int",
|
||||
"StubFloat": "Stub Float",
|
||||
}
|
53
tests/execution/testing_nodes/testing-pack/tools.py
Normal file
53
tests/execution/testing_nodes/testing-pack/tools.py
Normal file
@@ -0,0 +1,53 @@
|
||||
|
||||
def MakeSmartType(t):
|
||||
if isinstance(t, str):
|
||||
return SmartType(t)
|
||||
return t
|
||||
|
||||
class SmartType(str):
|
||||
def __ne__(self, other):
|
||||
if self == "*" or other == "*":
|
||||
return False
|
||||
selfset = set(self.split(','))
|
||||
otherset = set(other.split(','))
|
||||
return not selfset.issubset(otherset)
|
||||
|
||||
def VariantSupport():
|
||||
def decorator(cls):
|
||||
if hasattr(cls, "INPUT_TYPES"):
|
||||
old_input_types = getattr(cls, "INPUT_TYPES")
|
||||
def new_input_types(*args, **kwargs):
|
||||
types = old_input_types(*args, **kwargs)
|
||||
for category in ["required", "optional"]:
|
||||
if category not in types:
|
||||
continue
|
||||
for key, value in types[category].items():
|
||||
if isinstance(value, tuple):
|
||||
types[category][key] = (MakeSmartType(value[0]),) + value[1:]
|
||||
return types
|
||||
setattr(cls, "INPUT_TYPES", new_input_types)
|
||||
if hasattr(cls, "RETURN_TYPES"):
|
||||
old_return_types = cls.RETURN_TYPES
|
||||
setattr(cls, "RETURN_TYPES", tuple(MakeSmartType(x) for x in old_return_types))
|
||||
if hasattr(cls, "VALIDATE_INPUTS"):
|
||||
# Reflection is used to determine what the function signature is, so we can't just change the function signature
|
||||
raise NotImplementedError("VariantSupport does not support VALIDATE_INPUTS yet")
|
||||
else:
|
||||
def validate_inputs(input_types):
|
||||
inputs = cls.INPUT_TYPES()
|
||||
for key, value in input_types.items():
|
||||
if isinstance(value, SmartType):
|
||||
continue
|
||||
if "required" in inputs and key in inputs["required"]:
|
||||
expected_type = inputs["required"][key][0]
|
||||
elif "optional" in inputs and key in inputs["optional"]:
|
||||
expected_type = inputs["optional"][key][0]
|
||||
else:
|
||||
expected_type = None
|
||||
if expected_type is not None and MakeSmartType(value) != expected_type:
|
||||
return f"Invalid type of {key}: {value} (expected {expected_type})"
|
||||
return True
|
||||
setattr(cls, "VALIDATE_INPUTS", validate_inputs)
|
||||
return cls
|
||||
return decorator
|
||||
|
364
tests/execution/testing_nodes/testing-pack/util.py
Normal file
364
tests/execution/testing_nodes/testing-pack/util.py
Normal file
@@ -0,0 +1,364 @@
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from .tools import VariantSupport
|
||||
|
||||
@VariantSupport()
|
||||
class TestAccumulateNode:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"to_add": ("*",),
|
||||
},
|
||||
"optional": {
|
||||
"accumulation": ("ACCUMULATION",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("ACCUMULATION",)
|
||||
FUNCTION = "accumulate"
|
||||
|
||||
CATEGORY = "Testing/Lists"
|
||||
|
||||
def accumulate(self, to_add, accumulation = None):
|
||||
if accumulation is None:
|
||||
value = [to_add]
|
||||
else:
|
||||
value = accumulation["accum"] + [to_add]
|
||||
return ({"accum": value},)
|
||||
|
||||
@VariantSupport()
|
||||
class TestAccumulationHeadNode:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"accumulation": ("ACCUMULATION",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("ACCUMULATION", "*",)
|
||||
FUNCTION = "accumulation_head"
|
||||
|
||||
CATEGORY = "Testing/Lists"
|
||||
|
||||
def accumulation_head(self, accumulation):
|
||||
accum = accumulation["accum"]
|
||||
if len(accum) == 0:
|
||||
return (accumulation, None)
|
||||
else:
|
||||
return ({"accum": accum[1:]}, accum[0])
|
||||
|
||||
class TestAccumulationTailNode:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"accumulation": ("ACCUMULATION",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("ACCUMULATION", "*",)
|
||||
FUNCTION = "accumulation_tail"
|
||||
|
||||
CATEGORY = "Testing/Lists"
|
||||
|
||||
def accumulation_tail(self, accumulation):
|
||||
accum = accumulation["accum"]
|
||||
if len(accum) == 0:
|
||||
return (None, accumulation)
|
||||
else:
|
||||
return ({"accum": accum[:-1]}, accum[-1])
|
||||
|
||||
@VariantSupport()
|
||||
class TestAccumulationToListNode:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"accumulation": ("ACCUMULATION",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("*",)
|
||||
OUTPUT_IS_LIST = (True,)
|
||||
|
||||
FUNCTION = "accumulation_to_list"
|
||||
|
||||
CATEGORY = "Testing/Lists"
|
||||
|
||||
def accumulation_to_list(self, accumulation):
|
||||
return (accumulation["accum"],)
|
||||
|
||||
@VariantSupport()
|
||||
class TestListToAccumulationNode:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"list": ("*",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("ACCUMULATION",)
|
||||
INPUT_IS_LIST = (True,)
|
||||
|
||||
FUNCTION = "list_to_accumulation"
|
||||
|
||||
CATEGORY = "Testing/Lists"
|
||||
|
||||
def list_to_accumulation(self, list):
|
||||
return ({"accum": list},)
|
||||
|
||||
@VariantSupport()
|
||||
class TestAccumulationGetLengthNode:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"accumulation": ("ACCUMULATION",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("INT",)
|
||||
|
||||
FUNCTION = "accumlength"
|
||||
|
||||
CATEGORY = "Testing/Lists"
|
||||
|
||||
def accumlength(self, accumulation):
|
||||
return (len(accumulation['accum']),)
|
||||
|
||||
@VariantSupport()
|
||||
class TestAccumulationGetItemNode:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"accumulation": ("ACCUMULATION",),
|
||||
"index": ("INT", {"default":0, "step":1})
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("*",)
|
||||
|
||||
FUNCTION = "get_item"
|
||||
|
||||
CATEGORY = "Testing/Lists"
|
||||
|
||||
def get_item(self, accumulation, index):
|
||||
return (accumulation['accum'][index],)
|
||||
|
||||
@VariantSupport()
|
||||
class TestAccumulationSetItemNode:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"accumulation": ("ACCUMULATION",),
|
||||
"index": ("INT", {"default":0, "step":1}),
|
||||
"value": ("*",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("ACCUMULATION",)
|
||||
|
||||
FUNCTION = "set_item"
|
||||
|
||||
CATEGORY = "Testing/Lists"
|
||||
|
||||
def set_item(self, accumulation, index, value):
|
||||
new_accum = accumulation['accum'][:]
|
||||
new_accum[index] = value
|
||||
return ({"accum": new_accum},)
|
||||
|
||||
class TestIntMathOperation:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
|
||||
"b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
|
||||
"operation": (["add", "subtract", "multiply", "divide", "modulo", "power"],),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("INT",)
|
||||
FUNCTION = "int_math_operation"
|
||||
|
||||
CATEGORY = "Testing/Logic"
|
||||
|
||||
def int_math_operation(self, a, b, operation):
|
||||
if operation == "add":
|
||||
return (a + b,)
|
||||
elif operation == "subtract":
|
||||
return (a - b,)
|
||||
elif operation == "multiply":
|
||||
return (a * b,)
|
||||
elif operation == "divide":
|
||||
return (a // b,)
|
||||
elif operation == "modulo":
|
||||
return (a % b,)
|
||||
elif operation == "power":
|
||||
return (a ** b,)
|
||||
|
||||
|
||||
from .flow_control import NUM_FLOW_SOCKETS
|
||||
@VariantSupport()
|
||||
class TestForLoopOpen:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"remaining": ("INT", {"default": 1, "min": 0, "max": 100000, "step": 1}),
|
||||
},
|
||||
"optional": {
|
||||
f"initial_value{i}": ("*",) for i in range(1, NUM_FLOW_SOCKETS)
|
||||
},
|
||||
"hidden": {
|
||||
"initial_value0": ("*",)
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = tuple(["FLOW_CONTROL", "INT",] + ["*"] * (NUM_FLOW_SOCKETS-1))
|
||||
RETURN_NAMES = tuple(["flow_control", "remaining"] + [f"value{i}" for i in range(1, NUM_FLOW_SOCKETS)])
|
||||
FUNCTION = "for_loop_open"
|
||||
|
||||
CATEGORY = "Testing/Flow"
|
||||
|
||||
def for_loop_open(self, remaining, **kwargs):
|
||||
graph = GraphBuilder()
|
||||
if "initial_value0" in kwargs:
|
||||
remaining = kwargs["initial_value0"]
|
||||
graph.node("TestWhileLoopOpen", condition=remaining, initial_value0=remaining, **{(f"initial_value{i}"): kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)})
|
||||
outputs = [kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)]
|
||||
return {
|
||||
"result": tuple(["stub", remaining] + outputs),
|
||||
"expand": graph.finalize(),
|
||||
}
|
||||
|
||||
@VariantSupport()
|
||||
class TestForLoopClose:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"flow_control": ("FLOW_CONTROL", {"rawLink": True}),
|
||||
},
|
||||
"optional": {
|
||||
f"initial_value{i}": ("*",{"rawLink": True}) for i in range(1, NUM_FLOW_SOCKETS)
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = tuple(["*"] * (NUM_FLOW_SOCKETS-1))
|
||||
RETURN_NAMES = tuple([f"value{i}" for i in range(1, NUM_FLOW_SOCKETS)])
|
||||
FUNCTION = "for_loop_close"
|
||||
|
||||
CATEGORY = "Testing/Flow"
|
||||
|
||||
def for_loop_close(self, flow_control, **kwargs):
|
||||
graph = GraphBuilder()
|
||||
while_open = flow_control[0]
|
||||
sub = graph.node("TestIntMathOperation", operation="subtract", a=[while_open,1], b=1)
|
||||
cond = graph.node("TestToBoolNode", value=sub.out(0))
|
||||
input_values = {f"initial_value{i}": kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)}
|
||||
while_close = graph.node("TestWhileLoopClose",
|
||||
flow_control=flow_control,
|
||||
condition=cond.out(0),
|
||||
initial_value0=sub.out(0),
|
||||
**input_values)
|
||||
return {
|
||||
"result": tuple([while_close.out(i) for i in range(1, NUM_FLOW_SOCKETS)]),
|
||||
"expand": graph.finalize(),
|
||||
}
|
||||
|
||||
NUM_LIST_SOCKETS = 10
|
||||
@VariantSupport()
|
||||
class TestMakeListNode:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value1": ("*",),
|
||||
},
|
||||
"optional": {
|
||||
f"value{i}": ("*",) for i in range(1, NUM_LIST_SOCKETS)
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("*",)
|
||||
FUNCTION = "make_list"
|
||||
OUTPUT_IS_LIST = (True,)
|
||||
|
||||
CATEGORY = "Testing/Lists"
|
||||
|
||||
def make_list(self, **kwargs):
|
||||
result = []
|
||||
for i in range(NUM_LIST_SOCKETS):
|
||||
if f"value{i}" in kwargs:
|
||||
result.append(kwargs[f"value{i}"])
|
||||
return (result,)
|
||||
|
||||
UTILITY_NODE_CLASS_MAPPINGS = {
|
||||
"TestAccumulateNode": TestAccumulateNode,
|
||||
"TestAccumulationHeadNode": TestAccumulationHeadNode,
|
||||
"TestAccumulationTailNode": TestAccumulationTailNode,
|
||||
"TestAccumulationToListNode": TestAccumulationToListNode,
|
||||
"TestListToAccumulationNode": TestListToAccumulationNode,
|
||||
"TestAccumulationGetLengthNode": TestAccumulationGetLengthNode,
|
||||
"TestAccumulationGetItemNode": TestAccumulationGetItemNode,
|
||||
"TestAccumulationSetItemNode": TestAccumulationSetItemNode,
|
||||
"TestForLoopOpen": TestForLoopOpen,
|
||||
"TestForLoopClose": TestForLoopClose,
|
||||
"TestIntMathOperation": TestIntMathOperation,
|
||||
"TestMakeListNode": TestMakeListNode,
|
||||
}
|
||||
UTILITY_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TestAccumulateNode": "Accumulate",
|
||||
"TestAccumulationHeadNode": "Accumulation Head",
|
||||
"TestAccumulationTailNode": "Accumulation Tail",
|
||||
"TestAccumulationToListNode": "Accumulation to List",
|
||||
"TestListToAccumulationNode": "List to Accumulation",
|
||||
"TestAccumulationGetLengthNode": "Accumulation Get Length",
|
||||
"TestAccumulationGetItemNode": "Accumulation Get Item",
|
||||
"TestAccumulationSetItemNode": "Accumulation Set Item",
|
||||
"TestForLoopOpen": "For Loop Open",
|
||||
"TestForLoopClose": "For Loop Close",
|
||||
"TestIntMathOperation": "Int Math Operation",
|
||||
"TestMakeListNode": "Make List",
|
||||
}
|
Reference in New Issue
Block a user