Support for async execution functions

This commit adds support for node execution functions defined as async. While
an async node's execution function is awaiting, we can continue executing
other nodes.

Standard uses of `await` should "just work", but people will still have
to be careful if they spawn actual threads. Because torch doesn't really
have async/await versions of functions, this won't particularly help
with most locally-executing nodes, but it does work for e.g. web
requests to other machines.

In addition to the execute function, the `VALIDATE_INPUTS` and
`check_lazy_status` functions can also be defined as async, though we'll
only resolve one node at a time right now for those.
Jacob Segal 2025-04-19 00:06:13 -07:00
parent 772de7c006
commit 46c8311d14
11 changed files with 745 additions and 67 deletions
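
As an illustration of what this enables, a node that talks to another machine might look like the following sketch (hypothetical node, not part of this diff; assumes the aiohttp library the server already uses):

import aiohttp

class RemoteCaptionExample:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"url": ("STRING", {"default": "http://localhost:9000/caption"})}}

    RETURN_TYPES = ("STRING",)
    FUNCTION = "fetch"
    CATEGORY = "example"

    # Defined as async, so other nodes keep executing while the request is in flight.
    async def fetch(self, url):
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as resp:
                return (await resp.text(),)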

comfy/utils.py

@ -997,11 +997,12 @@ def set_progress_bar_global_hook(function):
PROGRESS_BAR_HOOK = function PROGRESS_BAR_HOOK = function
class ProgressBar: class ProgressBar:
def __init__(self, total): def __init__(self, total, node_id=None):
global PROGRESS_BAR_HOOK global PROGRESS_BAR_HOOK
self.total = total self.total = total
self.current = 0 self.current = 0
self.hook = PROGRESS_BAR_HOOK self.hook = PROGRESS_BAR_HOOK
self.node_id = node_id
def update_absolute(self, value, total=None, preview=None): def update_absolute(self, value, total=None, preview=None):
if total is not None: if total is not None:
@ -1010,7 +1011,7 @@ class ProgressBar:
value = self.total value = self.total
self.current = value self.current = value
if self.hook is not None: if self.hook is not None:
self.hook(self.current, self.total, preview) self.hook(self.current, self.total, preview, node_id=self.node_id)
def update(self, value): def update(self, value):
self.update_absolute(self.current + value) self.update_absolute(self.current + value)

comfy_execution/caching.py

@ -1,6 +1,7 @@
import itertools import itertools
from typing import Sequence, Mapping, Dict from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt from comfy_execution.graph import DynamicPrompt
from abc import ABC, abstractmethod
import nodes import nodes
@ -16,12 +17,13 @@ def include_unique_id_in_input(class_type: str) -> bool:
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values() NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class CacheKeySet: class CacheKeySet(ABC):
def __init__(self, dynprompt, node_ids, is_changed_cache): def __init__(self, dynprompt, node_ids, is_changed_cache):
self.keys = {} self.keys = {}
self.subcache_keys = {} self.subcache_keys = {}
def add_keys(self, node_ids): @abstractmethod
async def add_keys(self, node_ids):
raise NotImplementedError() raise NotImplementedError()
def all_node_ids(self): def all_node_ids(self):
@ -60,9 +62,8 @@ class CacheKeySetID(CacheKeySet):
def __init__(self, dynprompt, node_ids, is_changed_cache): def __init__(self, dynprompt, node_ids, is_changed_cache):
super().__init__(dynprompt, node_ids, is_changed_cache) super().__init__(dynprompt, node_ids, is_changed_cache)
self.dynprompt = dynprompt self.dynprompt = dynprompt
self.add_keys(node_ids)
def add_keys(self, node_ids): async def add_keys(self, node_ids):
for node_id in node_ids: for node_id in node_ids:
if node_id in self.keys: if node_id in self.keys:
continue continue
@ -77,37 +78,36 @@ class CacheKeySetInputSignature(CacheKeySet):
super().__init__(dynprompt, node_ids, is_changed_cache) super().__init__(dynprompt, node_ids, is_changed_cache)
self.dynprompt = dynprompt self.dynprompt = dynprompt
self.is_changed_cache = is_changed_cache self.is_changed_cache = is_changed_cache
self.add_keys(node_ids)
def include_node_id_in_input(self) -> bool: def include_node_id_in_input(self) -> bool:
return False return False
def add_keys(self, node_ids): async def add_keys(self, node_ids):
for node_id in node_ids: for node_id in node_ids:
if node_id in self.keys: if node_id in self.keys:
continue continue
if not self.dynprompt.has_node(node_id): if not self.dynprompt.has_node(node_id):
continue continue
node = self.dynprompt.get_node(node_id) node = self.dynprompt.get_node(node_id)
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id) self.keys[node_id] = await self.get_node_signature(self.dynprompt, node_id)
self.subcache_keys[node_id] = (node_id, node["class_type"]) self.subcache_keys[node_id] = (node_id, node["class_type"])
def get_node_signature(self, dynprompt, node_id): async def get_node_signature(self, dynprompt, node_id):
signature = [] signature = []
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id) ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping)) signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
for ancestor_id in ancestors: for ancestor_id in ancestors:
signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)) signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
return to_hashable(signature) return to_hashable(signature)
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
if not dynprompt.has_node(node_id): if not dynprompt.has_node(node_id):
# This node doesn't exist -- we can't cache it. # This node doesn't exist -- we can't cache it.
return [float("NaN")] return [float("NaN")]
node = dynprompt.get_node(node_id) node = dynprompt.get_node(node_id)
class_type = node["class_type"] class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
signature = [class_type, self.is_changed_cache.get(node_id)] signature = [class_type, await self.is_changed_cache.get(node_id)]
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type): if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
signature.append(node_id) signature.append(node_id)
inputs = node["inputs"] inputs = node["inputs"]
@ -150,9 +150,10 @@ class BasicCache:
self.cache = {} self.cache = {}
self.subcaches = {} self.subcaches = {}
def set_prompt(self, dynprompt, node_ids, is_changed_cache): async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
self.dynprompt = dynprompt self.dynprompt = dynprompt
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache) self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
await self.cache_key_set.add_keys(node_ids)
self.is_changed_cache = is_changed_cache self.is_changed_cache = is_changed_cache
self.initialized = True self.initialized = True
@ -201,13 +202,13 @@ class BasicCache:
else: else:
return None return None
def _ensure_subcache(self, node_id, children_ids): async def _ensure_subcache(self, node_id, children_ids):
subcache_key = self.cache_key_set.get_subcache_key(node_id) subcache_key = self.cache_key_set.get_subcache_key(node_id)
subcache = self.subcaches.get(subcache_key, None) subcache = self.subcaches.get(subcache_key, None)
if subcache is None: if subcache is None:
subcache = BasicCache(self.key_class) subcache = BasicCache(self.key_class)
self.subcaches[subcache_key] = subcache self.subcaches[subcache_key] = subcache
subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache) await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
return subcache return subcache
def _get_subcache(self, node_id): def _get_subcache(self, node_id):
@ -259,10 +260,10 @@ class HierarchicalCache(BasicCache):
assert cache is not None assert cache is not None
cache._set_immediate(node_id, value) cache._set_immediate(node_id, value)
def ensure_subcache_for(self, node_id, children_ids): async def ensure_subcache_for(self, node_id, children_ids):
cache = self._get_cache_for(node_id) cache = self._get_cache_for(node_id)
assert cache is not None assert cache is not None
return cache._ensure_subcache(node_id, children_ids) return await cache._ensure_subcache(node_id, children_ids)
class LRUCache(BasicCache): class LRUCache(BasicCache):
def __init__(self, key_class, max_size=100): def __init__(self, key_class, max_size=100):
@ -273,8 +274,8 @@ class LRUCache(BasicCache):
self.used_generation = {} self.used_generation = {}
self.children = {} self.children = {}
def set_prompt(self, dynprompt, node_ids, is_changed_cache): async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
super().set_prompt(dynprompt, node_ids, is_changed_cache) await super().set_prompt(dynprompt, node_ids, is_changed_cache)
self.generation += 1 self.generation += 1
for node_id in node_ids: for node_id in node_ids:
self._mark_used(node_id) self._mark_used(node_id)
@ -303,11 +304,11 @@ class LRUCache(BasicCache):
self._mark_used(node_id) self._mark_used(node_id)
return self._set_immediate(node_id, value) return self._set_immediate(node_id, value)
def ensure_subcache_for(self, node_id, children_ids): async def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes # Just uses subcaches for tracking 'live' nodes
super()._ensure_subcache(node_id, children_ids) await super()._ensure_subcache(node_id, children_ids)
self.cache_key_set.add_keys(children_ids) await self.cache_key_set.add_keys(children_ids)
self._mark_used(node_id) self._mark_used(node_id)
cache_key = self.cache_key_set.get_data_key(node_id) cache_key = self.cache_key_set.get_data_key(node_id)
self.children[cache_key] = [] self.children[cache_key] = []
@ -337,7 +338,7 @@ class DependencyAwareCache(BasicCache):
self.ancestors = {} # Maps node_id -> set of ancestor node_ids self.ancestors = {} # Maps node_id -> set of ancestor node_ids
self.executed_nodes = set() # Tracks nodes that have been executed self.executed_nodes = set() # Tracks nodes that have been executed
def set_prompt(self, dynprompt, node_ids, is_changed_cache): async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
""" """
Clear the entire cache and rebuild the dependency graph. Clear the entire cache and rebuild the dependency graph.
@ -354,7 +355,7 @@ class DependencyAwareCache(BasicCache):
self.executed_nodes.clear() self.executed_nodes.clear()
# Call the parent method to initialize the cache with the new prompt # Call the parent method to initialize the cache with the new prompt
super().set_prompt(dynprompt, node_ids, is_changed_cache) await super().set_prompt(dynprompt, node_ids, is_changed_cache)
# Rebuild the dependency graph # Rebuild the dependency graph
self._build_dependency_graph(dynprompt, node_ids) self._build_dependency_graph(dynprompt, node_ids)
@ -405,7 +406,7 @@ class DependencyAwareCache(BasicCache):
""" """
return self._get_immediate(node_id) return self._get_immediate(node_id)
def ensure_subcache_for(self, node_id, children_ids): async def ensure_subcache_for(self, node_id, children_ids):
""" """
Ensure a subcache exists for a node and update dependencies. Ensure a subcache exists for a node and update dependencies.
@ -416,7 +417,7 @@ class DependencyAwareCache(BasicCache):
Returns: Returns:
The subcache object for the node. The subcache object for the node.
""" """
subcache = super()._ensure_subcache(node_id, children_ids) subcache = await super()._ensure_subcache(node_id, children_ids)
for child_id in children_ids: for child_id in children_ids:
self.descendants[node_id].add(child_id) self.descendants[node_id].add(child_id)
self.ancestors[child_id].add(node_id) self.ancestors[child_id].add(node_id)
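
Note that `add_keys` now has to be called explicitly after construction (see `set_prompt` above): `__init__` cannot await, so the diff uses the common two-step async construction pattern. A minimal standalone sketch:

import asyncio

class KeySet:
    def __init__(self):
        self.keys = {}

    async def add_keys(self, node_ids):
        for node_id in node_ids:
            # stands in for awaiting is_changed_cache.get(node_id)
            self.keys[node_id] = await compute_signature(node_id)

async def compute_signature(node_id):  # hypothetical helper
    await asyncio.sleep(0)
    return ("sig", node_id)

async def main():
    ks = KeySet()                  # construct synchronously...
    await ks.add_keys(["1", "2"])  # ...then populate asynchronously
    print(ks.keys)

asyncio.run(main())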

comfy_execution/graph.py

@ -2,6 +2,7 @@ from __future__ import annotations
from typing import Type, Literal from typing import Type, Literal
import nodes import nodes
import asyncio
from comfy_execution.graph_utils import is_link from comfy_execution.graph_utils import is_link
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
@ -100,6 +101,8 @@ class TopologicalSort:
self.pendingNodes = {} self.pendingNodes = {}
self.blockCount = {} # Number of nodes this node is directly blocked by self.blockCount = {} # Number of nodes this node is directly blocked by
self.blocking = {} # Which nodes are blocked by this node self.blocking = {} # Which nodes are blocked by this node
self.externalBlocks = 0
self.unblockedEvent = asyncio.Event()
def get_input_info(self, unique_id, input_name): def get_input_info(self, unique_id, input_name):
class_type = self.dynprompt.get_node(unique_id)["class_type"] class_type = self.dynprompt.get_node(unique_id)["class_type"]
@ -153,6 +156,16 @@ class TopologicalSort:
for link in links: for link in links:
self.add_strong_link(*link) self.add_strong_link(*link)
def add_external_block(self, node_id):
assert node_id in self.blockCount, "Can't add external block to a node that isn't pending"
self.externalBlocks += 1
self.blockCount[node_id] += 1
def unblock():
self.externalBlocks -= 1
self.blockCount[node_id] -= 1
self.unblockedEvent.set()
return unblock
def is_cached(self, node_id): def is_cached(self, node_id):
return False return False
@ -181,11 +194,16 @@ class ExecutionList(TopologicalSort):
def is_cached(self, node_id): def is_cached(self, node_id):
return self.output_cache.get(node_id) is not None return self.output_cache.get(node_id) is not None
def stage_node_execution(self): async def stage_node_execution(self):
assert self.staged_node_id is None assert self.staged_node_id is None
if self.is_empty(): if self.is_empty():
return None, None, None return None, None, None
available = self.get_ready_nodes() available = self.get_ready_nodes()
while len(available) == 0 and self.externalBlocks > 0:
# Wait for an external block to be released
await self.unblockedEvent.wait()
self.unblockedEvent.clear()
available = self.get_ready_nodes()
if len(available) == 0: if len(available) == 0:
cycled_nodes = self.get_nodes_in_cycle() cycled_nodes = self.get_nodes_in_cycle()
# Because cycles composed entirely of static nodes are caught during initial validation, # Because cycles composed entirely of static nodes are caught during initial validation,
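
`add_external_block` is what lets an async node park itself: the staging loop above waits on `unblockedEvent` whenever nothing is ready but external blocks remain. A standalone sketch of the pattern (illustrative names):

import asyncio

class MiniExecutionList:
    def __init__(self):
        self.block_count = {"n1": 0}
        self.external_blocks = 0
        self.unblocked_event = asyncio.Event()

    def add_external_block(self, node_id):
        self.external_blocks += 1
        self.block_count[node_id] += 1
        def unblock():
            self.external_blocks -= 1
            self.block_count[node_id] -= 1
            self.unblocked_event.set()  # wake the staging loop
        return unblock

    async def stage(self):
        # Nothing ready, but an external unblock is coming -- wait for it.
        while self.block_count["n1"] > 0 and self.external_blocks > 0:
            await self.unblocked_event.wait()
            self.unblocked_event.clear()
        return "n1"

async def main():
    el = MiniExecutionList()
    unblock = el.add_external_block("n1")
    asyncio.get_running_loop().call_later(0.1, unblock)
    print(await el.stage())  # "n1", after roughly 0.1s

asyncio.run(main())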

comfy_execution/progress.py (new file, 288 lines)

@ -0,0 +1,288 @@
from typing import TypedDict, Dict, Optional
from typing_extensions import override
from PIL import Image
from enum import Enum
from abc import ABC
from tqdm import tqdm
from comfy_execution.graph import DynamicPrompt
from protocol import BinaryEventTypes
class NodeState(Enum):
Pending = "pending"
Running = "running"
Finished = "finished"
Error = "error"
class NodeProgressState(TypedDict):
"""
A class to represent the state of a node's progress.
"""
state: NodeState
value: float
max: float
class ProgressHandler(ABC):
"""
Abstract base class for progress handlers.
Progress handlers receive progress updates and display them in various ways.
"""
def __init__(self, name: str):
self.name = name
self.enabled = True
def set_registry(self, registry: "ProgressRegistry"):
pass
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
"""Called when a node starts processing"""
pass
def update_handler(self, node_id: str, value: float, max_value: float,
state: NodeProgressState, prompt_id: str, image: Optional[Image.Image] = None):
"""Called when a node's progress is updated"""
pass
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
"""Called when a node finishes processing"""
pass
def reset(self):
"""Called when the progress registry is reset"""
pass
def enable(self):
"""Enable this handler"""
self.enabled = True
def disable(self):
"""Disable this handler"""
self.enabled = False
class CLIProgressHandler(ProgressHandler):
"""
Handler that displays progress using tqdm progress bars in the CLI.
"""
def __init__(self):
super().__init__("cli")
self.progress_bars: Dict[str, tqdm] = {}
@override
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
# Create a new tqdm progress bar
if node_id not in self.progress_bars:
self.progress_bars[node_id] = tqdm(
total=state["max"],
desc=f"Node {node_id}",
unit="steps",
leave=True,
position=len(self.progress_bars)
)
@override
def update_handler(self, node_id: str, value: float, max_value: float,
state: NodeProgressState, prompt_id: str, image: Optional[Image.Image] = None):
# Handle case where start_handler wasn't called
if node_id not in self.progress_bars:
self.progress_bars[node_id] = tqdm(
total=max_value,
desc=f"Node {node_id}",
unit="steps",
leave=True,
position=len(self.progress_bars)
)
self.progress_bars[node_id].update(value)
else:
# Update existing progress bar
if max_value != self.progress_bars[node_id].total:
self.progress_bars[node_id].total = max_value
# Calculate the update amount (difference from current position)
current_position = self.progress_bars[node_id].n
update_amount = value - current_position
if update_amount > 0:
self.progress_bars[node_id].update(update_amount)
@override
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
# Complete and close the progress bar if it exists
if node_id in self.progress_bars:
# Ensure the bar shows 100% completion
remaining = state["max"] - self.progress_bars[node_id].n
if remaining > 0:
self.progress_bars[node_id].update(remaining)
self.progress_bars[node_id].close()
del self.progress_bars[node_id]
@override
def reset(self):
# Close all progress bars
for bar in self.progress_bars.values():
bar.close()
self.progress_bars.clear()
class WebUIProgressHandler(ProgressHandler):
"""
Handler that sends progress updates to the WebUI via WebSockets.
"""
def __init__(self, server_instance):
super().__init__("webui")
self.server_instance = server_instance
def set_registry(self, registry: "ProgressRegistry"):
self.registry = registry
def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]):
"""Send the current progress state to the client"""
if self.server_instance is None:
return
# Only send info for non-pending nodes
active_nodes = {
node_id: {
"value": state["value"],
"max": state["max"],
"state": state["state"].value,
"node_id": node_id,
"prompt_id": prompt_id,
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id)
}
for node_id, state in nodes.items()
if state["state"] != NodeState.Pending
}
# Send a combined progress_state message with all node states
self.server_instance.send_sync("progress_state", {
"prompt_id": prompt_id,
"nodes": active_nodes
})
@override
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
# Send progress state of all nodes
if self.registry:
self._send_progress_state(prompt_id, self.registry.nodes)
@override
def update_handler(self, node_id: str, value: float, max_value: float,
state: NodeProgressState, prompt_id: str, image: Optional[Image.Image] = None):
# Send progress state of all nodes
if self.registry:
self._send_progress_state(prompt_id, self.registry.nodes)
if image:
metadata = {
"node_id": node_id,
"prompt_id": prompt_id,
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id)
}
self.server_instance.send_sync(BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, (image, metadata), self.server_instance.client_id)
@override
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
# Send progress state of all nodes
if self.registry:
self._send_progress_state(prompt_id, self.registry.nodes)
class ProgressRegistry:
"""
Registry that maintains node progress state and notifies registered handlers.
"""
def __init__(self, prompt_id: str, dynprompt: DynamicPrompt):
self.prompt_id = prompt_id
self.dynprompt = dynprompt
self.nodes: Dict[str, NodeProgressState] = {}
self.handlers: Dict[str, ProgressHandler] = {}
def register_handler(self, handler: ProgressHandler) -> None:
"""Register a progress handler"""
self.handlers[handler.name] = handler
def unregister_handler(self, handler_name: str) -> None:
"""Unregister a progress handler"""
if handler_name in self.handlers:
# Allow handler to clean up resources
self.handlers[handler_name].reset()
del self.handlers[handler_name]
def enable_handler(self, handler_name: str) -> None:
"""Enable a progress handler"""
if handler_name in self.handlers:
self.handlers[handler_name].enable()
def disable_handler(self, handler_name: str) -> None:
"""Disable a progress handler"""
if handler_name in self.handlers:
self.handlers[handler_name].disable()
def ensure_entry(self, node_id: str) -> NodeProgressState:
"""Ensure a node entry exists"""
if node_id not in self.nodes:
self.nodes[node_id] = NodeProgressState(
state = NodeState.Pending,
value = 0,
max = 1
)
return self.nodes[node_id]
def start_progress(self, node_id: str) -> None:
"""Start progress tracking for a node"""
entry = self.ensure_entry(node_id)
entry["state"] = NodeState.Running
entry["value"] = 0.0
entry["max"] = 1.0
# Notify all enabled handlers
for handler in self.handlers.values():
if handler.enabled:
handler.start_handler(node_id, entry, self.prompt_id)
def update_progress(self, node_id: str, value: float, max_value: float, image: Optional[Image.Image]) -> None:
"""Update progress for a node"""
entry = self.ensure_entry(node_id)
entry["state"] = NodeState.Running
entry["value"] = value
entry["max"] = max_value
# Notify all enabled handlers
for handler in self.handlers.values():
if handler.enabled:
handler.update_handler(node_id, value, max_value, entry, self.prompt_id, image)
def finish_progress(self, node_id: str) -> None:
"""Finish progress tracking for a node"""
entry = self.ensure_entry(node_id)
entry["state"] = NodeState.Finished
entry["value"] = entry["max"]
# Notify all enabled handlers
for handler in self.handlers.values():
if handler.enabled:
handler.finish_handler(node_id, entry, self.prompt_id)
def reset_handlers(self) -> None:
"""Reset all handlers"""
for handler in self.handlers.values():
handler.reset()
# Global registry instance
global_progress_registry: ProgressRegistry = ProgressRegistry(prompt_id="", dynprompt=DynamicPrompt({}))
def reset_progress_state(prompt_id: str, dynprompt: DynamicPrompt) -> None:
global global_progress_registry
# Reset existing handlers if registry exists
if global_progress_registry is not None:
global_progress_registry.reset_handlers()
# Create new registry
global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
def add_progress_handler(handler: ProgressHandler) -> None:
handler.set_registry(global_progress_registry)
global_progress_registry.register_handler(handler)
def get_progress_state() -> ProgressRegistry:
return global_progress_registry
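
Usage is a subclass plus `add_progress_handler`; a minimal sketch of a third handler (hypothetical, alongside the CLI and WebUI ones above):

from comfy_execution.progress import ProgressHandler, add_progress_handler

class LogProgressHandler(ProgressHandler):
    """Hypothetical handler that just logs progress lines."""
    def __init__(self):
        super().__init__("log")

    def update_handler(self, node_id, value, max_value, state, prompt_id, image=None):
        print(f"[{prompt_id}] node {node_id}: {value:.0f}/{max_value:.0f}")

add_progress_handler(LogProgressHandler())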

comfy_execution/utils.py (new file, 46 lines)

@ -0,0 +1,46 @@
import contextvars
from typing import Optional, NamedTuple
class ExecutionContext(NamedTuple):
"""
Context information about the currently executing node.
Attributes:
prompt_id: The ID of the currently executing prompt
node_id: The ID of the currently executing node
list_index: The index in a list being processed (for operations on batches/lists)
"""
prompt_id: str
node_id: str
list_index: Optional[int]
current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None)
def get_executing_context() -> Optional[ExecutionContext]:
return current_executing_context.get(None)
class CurrentNodeContext:
"""
Context manager for setting the current executing node context.
Sets the current_executing_context on enter and resets it on exit.
Example:
with CurrentNodeContext(node_id="123", list_index=0):
# Code that should run with the current node context set
process_image()
"""
def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
self.context = ExecutionContext(
prompt_id= prompt_id,
node_id= node_id,
list_index= list_index
)
self.token = None
def __enter__(self):
self.token = current_executing_context.set(self.context)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.token is not None:
current_executing_context.reset(self.token)
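
For example, a deeply nested helper can recover which node invoked it without threading IDs through every call (a sketch with illustrative values):

from comfy_execution.utils import CurrentNodeContext, get_executing_context

def report_caller():
    ctx = get_executing_context()
    if ctx is not None:
        print(f"prompt={ctx.prompt_id} node={ctx.node_id} index={ctx.list_index}")
    else:
        print("no node is executing")

with CurrentNodeContext(prompt_id="p-123", node_id="7", list_index=0):
    report_caller()  # prompt=p-123 node=7 index=0
report_caller()      # no node is executing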

execution.py

@ -8,12 +8,14 @@ import time
import traceback import traceback
from enum import Enum from enum import Enum
from typing import List, Literal, NamedTuple, Optional from typing import List, Literal, NamedTuple, Optional
import asyncio
import torch import torch
import comfy.model_management import comfy.model_management
import nodes import nodes
from comfy_execution.caching import ( from comfy_execution.caching import (
BasicCache,
CacheKeySetID, CacheKeySetID,
CacheKeySetInputSignature, CacheKeySetInputSignature,
DependencyAwareCache, DependencyAwareCache,
@ -28,6 +30,8 @@ from comfy_execution.graph import (
) )
from comfy_execution.graph_utils import GraphBuilder, is_link from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input from comfy_execution.validation import validate_node_input
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
from comfy_execution.utils import CurrentNodeContext
class ExecutionResult(Enum): class ExecutionResult(Enum):
@ -39,12 +43,13 @@ class DuplicateNodeError(Exception):
pass pass
class IsChangedCache: class IsChangedCache:
def __init__(self, dynprompt, outputs_cache): def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache):
self.prompt_id = prompt_id
self.dynprompt = dynprompt self.dynprompt = dynprompt
self.outputs_cache = outputs_cache self.outputs_cache = outputs_cache
self.is_changed = {} self.is_changed = {}
def get(self, node_id): async def get(self, node_id):
if node_id in self.is_changed: if node_id in self.is_changed:
return self.is_changed[node_id] return self.is_changed[node_id]
@ -62,7 +67,8 @@ class IsChangedCache:
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None) input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
try: try:
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED") is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, "IS_CHANGED")
is_changed = await resolve_map_node_over_list_results(is_changed)
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
except Exception as e: except Exception as e:
logging.warning("WARNING: {}".format(e)) logging.warning("WARNING: {}".format(e))
@ -164,7 +170,19 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
map_node_over_list = None #Don't hook this please map_node_over_list = None #Don't hook this please
def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): async def resolve_map_node_over_list_results(results):
remaining = [x for x in results if isinstance(x, asyncio.Task) and not x.done()]
if len(remaining) == 0:
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
else:
done, pending = await asyncio.wait(remaining)
for task in done:
exc = task.exception()
if exc is not None:
raise exc
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
# check if node wants the lists # check if node wants the lists
input_is_list = getattr(obj, "INPUT_IS_LIST", False) input_is_list = getattr(obj, "INPUT_IS_LIST", False)
@ -178,7 +196,7 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
return {k: v[i if len(v) > i else -1] for k, v in d.items()} return {k: v[i if len(v) > i else -1] for k, v in d.items()}
results = [] results = []
def process_inputs(inputs, index=None, input_is_list=False): async def process_inputs(inputs, index=None, input_is_list=False):
if allow_interrupt: if allow_interrupt:
nodes.before_node_execution() nodes.before_node_execution()
execution_block = None execution_block = None
@ -194,20 +212,37 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
if execution_block is None: if execution_block is None:
if pre_execute_cb is not None and index is not None: if pre_execute_cb is not None and index is not None:
pre_execute_cb(index) pre_execute_cb(index)
results.append(getattr(obj, func)(**inputs)) f = getattr(obj, func)
if inspect.iscoroutinefunction(f):
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
with CurrentNodeContext(prompt_id, unique_id, list_index):
return await f(**args)
task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs))
# Give the task a chance to execute without yielding
await asyncio.sleep(0)
if task.done():
result = task.result()
results.append(result)
else:
results.append(task)
else:
with CurrentNodeContext(prompt_id, unique_id, index):
result = f(**inputs)
results.append(result)
else: else:
results.append(execution_block) results.append(execution_block)
if input_is_list: if input_is_list:
process_inputs(input_data_all, 0, input_is_list=input_is_list) await process_inputs(input_data_all, 0, input_is_list=input_is_list)
elif max_len_input == 0: elif max_len_input == 0:
process_inputs({}) await process_inputs({})
else: else:
for i in range(max_len_input): for i in range(max_len_input):
input_dict = slice_dict(input_data_all, i) input_dict = slice_dict(input_data_all, i)
process_inputs(input_dict, i) await process_inputs(input_dict, i)
return results return results
def merge_result_data(results, obj): def merge_result_data(results, obj):
# check which outputs need concatenating # check which outputs need concatenating
output = [] output = []
@ -229,11 +264,18 @@ def merge_result_data(results, obj):
output.append([o[i] for o in results]) output.append([o[i] for o in results])
return output return output
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None): async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
if has_pending_task:
return return_values, {}, False, has_pending_task
output, ui, has_subgraph = get_output_from_returns(return_values, obj)
return output, ui, has_subgraph, False
def get_output_from_returns(return_values, obj):
results = [] results = []
uis = [] uis = []
subgraph_results = [] subgraph_results = []
return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
has_subgraph = False has_subgraph = False
for i in range(len(return_values)): for i in range(len(return_values)):
r = return_values[i] r = return_values[i]
@ -267,6 +309,10 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
else: else:
output = [] output = []
ui = dict() ui = dict()
# TODO: Think there's an existing bug here
# If we're performing a subgraph expansion, we probably shouldn't be returning UI values yet.
# They'll get cached without the completed subgraphs. It's an edge case and I'm not aware of
# any nodes that use both subgraph expansion and custom UI outputs, but might be a problem in the future.
if len(uis) > 0: if len(uis) > 0:
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
return output, ui, has_subgraph return output, ui, has_subgraph
@ -279,7 +325,7 @@ def format_value(x):
else: else:
return str(x) return str(x)
def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results): async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
unique_id = current_item unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id) real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id)
@ -291,11 +337,16 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
if server.client_id is not None: if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {} cached_output = caches.ui.get(unique_id) or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
get_progress_state().finish_progress(unique_id)
return (ExecutionResult.SUCCESS, None, None) return (ExecutionResult.SUCCESS, None, None)
input_data_all = None input_data_all = None
try: try:
if unique_id in pending_subgraph_results: if unique_id in pending_async_nodes:
results = [r.result() if isinstance(r, asyncio.Task) else r for r in pending_async_nodes[unique_id]]
del pending_async_nodes[unique_id]
output_data, output_ui, has_subgraph = get_output_from_returns(results, class_def)
elif unique_id in pending_subgraph_results:
cached_results = pending_subgraph_results[unique_id] cached_results = pending_subgraph_results[unique_id]
resolved_outputs = [] resolved_outputs = []
for is_subgraph, result in cached_results: for is_subgraph, result in cached_results:
@ -317,6 +368,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
output_ui = [] output_ui = []
has_subgraph = False has_subgraph = False
else: else:
get_progress_state().start_progress(unique_id)
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
if server.client_id is not None: if server.client_id is not None:
server.last_node_id = display_node_id server.last_node_id = display_node_id
@ -328,7 +380,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
caches.objects.set(unique_id, obj) caches.objects.set(unique_id, obj)
if hasattr(obj, "check_lazy_status"): if hasattr(obj, "check_lazy_status"):
required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True) required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True)
required_inputs = await resolve_map_node_over_list_results(required_inputs)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
required_inputs = [x for x in required_inputs if isinstance(x,str) and ( required_inputs = [x for x in required_inputs if isinstance(x,str) and (
x not in input_data_all or x in missing_keys x not in input_data_all or x in missing_keys
@ -357,8 +410,18 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
else: else:
return block return block
def pre_execute_cb(call_index): def pre_execute_cb(call_index):
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
GraphBuilder.set_default_prefix(unique_id, call_index, 0) GraphBuilder.set_default_prefix(unique_id, call_index, 0)
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
unblock = execution_list.add_external_block(unique_id)
async def await_completion():
tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
await asyncio.gather(*tasks)
unblock()
asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None)
if len(output_ui) > 0: if len(output_ui) > 0:
caches.ui.set(unique_id, { caches.ui.set(unique_id, {
"meta": { "meta": {
@ -401,7 +464,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
cached_outputs.append((True, node_outputs)) cached_outputs.append((True, node_outputs))
new_node_ids = set(new_node_ids) new_node_ids = set(new_node_ids)
for cache in caches.all: for cache in caches.all:
cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused() subcache = await cache.ensure_subcache_for(unique_id, new_node_ids)
subcache.clean_unused()
for node_id in new_output_ids: for node_id in new_output_ids:
execution_list.add_node(node_id) execution_list.add_node(node_id)
for link in new_output_links: for link in new_output_links:
@ -446,6 +510,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
return (ExecutionResult.FAILURE, error_details, ex) return (ExecutionResult.FAILURE, error_details, ex)
get_progress_state().finish_progress(unique_id)
executed.add(unique_id) executed.add(unique_id)
return (ExecutionResult.SUCCESS, None, None) return (ExecutionResult.SUCCESS, None, None)
@ -500,6 +565,11 @@ class PromptExecutor:
self.add_message("execution_error", mes, broadcast=False) self.add_message("execution_error", mes, broadcast=False)
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
asyncio_loop = asyncio.new_event_loop()
asyncio.set_event_loop(asyncio_loop)
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
nodes.interrupt_processing(False) nodes.interrupt_processing(False)
if "client_id" in extra_data: if "client_id" in extra_data:
@ -512,9 +582,11 @@ class PromptExecutor:
with torch.inference_mode(): with torch.inference_mode():
dynamic_prompt = DynamicPrompt(prompt) dynamic_prompt = DynamicPrompt(prompt)
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs) reset_progress_state(prompt_id, dynamic_prompt)
add_progress_handler(WebUIProgressHandler(self.server))
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
for cache in self.caches.all: for cache in self.caches.all:
cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
cache.clean_unused() cache.clean_unused()
cached_nodes = [] cached_nodes = []
@ -527,6 +599,7 @@ class PromptExecutor:
{ "nodes": cached_nodes, "prompt_id": prompt_id}, { "nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False) broadcast=False)
pending_subgraph_results = {} pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
executed = set() executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids() current_outputs = self.caches.outputs.all_node_ids()
@ -534,12 +607,13 @@ class PromptExecutor:
execution_list.add_node(node_id) execution_list.add_node(node_id)
while not execution_list.is_empty(): while not execution_list.is_empty():
node_id, error, ex = execution_list.stage_node_execution() node_id, error, ex = await execution_list.stage_node_execution()
if error is not None: if error is not None:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break break
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
self.success = result != ExecutionResult.FAILURE self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE: if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
@ -569,7 +643,7 @@ class PromptExecutor:
comfy.model_management.unload_all_models() comfy.model_management.unload_all_models()
def validate_inputs(prompt, item, validated): async def validate_inputs(prompt_id, prompt, item, validated):
unique_id = item unique_id = item
if unique_id in validated: if unique_id in validated:
return validated[unique_id] return validated[unique_id]
@ -646,7 +720,7 @@ def validate_inputs(prompt, item, validated):
errors.append(error) errors.append(error)
continue continue
try: try:
r = validate_inputs(prompt, o_id, validated) r = await validate_inputs(prompt_id, prompt, o_id, validated)
if r[0] is False: if r[0] is False:
# `r` will be set in `validated[o_id]` already # `r` will be set in `validated[o_id]` already
valid = False valid = False
@ -771,7 +845,8 @@ def validate_inputs(prompt, item, validated):
input_filtered['input_types'] = [received_types] input_filtered['input_types'] = [received_types]
#ret = obj_class.VALIDATE_INPUTS(**input_filtered) #ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS")
ret = await resolve_map_node_over_list_results(ret)
for x in input_filtered: for x in input_filtered:
for i, r in enumerate(ret): for i, r in enumerate(ret):
if r is not True and not isinstance(r, ExecutionBlocker): if r is not True and not isinstance(r, ExecutionBlocker):
@ -804,7 +879,7 @@ def full_type_name(klass):
return klass.__qualname__ return klass.__qualname__
return module + '.' + klass.__qualname__ return module + '.' + klass.__qualname__
def validate_prompt(prompt): async def validate_prompt(prompt_id, prompt):
outputs = set() outputs = set()
for x in prompt: for x in prompt:
if 'class_type' not in prompt[x]: if 'class_type' not in prompt[x]:
@ -847,7 +922,7 @@ def validate_prompt(prompt):
valid = False valid = False
reasons = [] reasons = []
try: try:
m = validate_inputs(prompt, o, validated) m = await validate_inputs(prompt_id, prompt, o, validated)
valid = m[0] valid = m[0]
reasons = m[1] reasons = m[1]
except Exception as ex: except Exception as ex:
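
The heart of the executor change is the pattern in `_async_map_node_over_list`: start the coroutine as a task, give it one event-loop tick, and only defer the node if the task is still pending. Standalone, it reduces to something like:

import asyncio

async def call_maybe_async(fn, **inputs):
    if asyncio.iscoroutinefunction(fn):
        task = asyncio.create_task(fn(**inputs))
        await asyncio.sleep(0)  # one tick: lets the task run to its first real await
        return task.result() if task.done() else task  # value, or a pending Task
    return fn(**inputs)

async def main():
    async def quick():           # completes without ever really awaiting
        return 1
    async def slow():
        await asyncio.sleep(0.05)
        return 2
    print(await call_maybe_async(quick))    # 1 -- finished within the tick
    pending = await call_maybe_async(slow)  # still running: returned as a Task
    print(isinstance(pending, asyncio.Task), await pending)  # True 2

asyncio.run(main())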

main.py (21 lines changed)

@ -11,6 +11,8 @@ import itertools
import utils.extra_config import utils.extra_config
import logging import logging
import sys import sys
from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
if __name__ == "__main__": if __name__ == "__main__":
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes. #NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
@ -131,7 +133,7 @@ import comfy.utils
import execution import execution
import server import server
from server import BinaryEventTypes from protocol import BinaryEventTypes
import nodes import nodes
import comfy.model_management import comfy.model_management
import comfyui_version import comfyui_version
@ -227,14 +229,25 @@ async def run(server_instance, address='', port=8188, verbose=True, call_on_star
server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop() server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
) )
def hijack_progress(server_instance): def hijack_progress(server_instance):
def hook(value, total, preview_image): def hook(value, total, preview_image, prompt_id=None, node_id=None):
executing_context = get_executing_context()
if prompt_id is None and executing_context is not None:
prompt_id = executing_context.prompt_id
if node_id is None and executing_context is not None:
node_id = executing_context.node_id
comfy.model_management.throw_exception_if_processing_interrupted() comfy.model_management.throw_exception_if_processing_interrupted()
progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id} if prompt_id is None:
prompt_id = server_instance.last_prompt_id
if node_id is None:
node_id = server_instance.last_node_id
progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
get_progress_state().update_progress(node_id, value, total, preview_image)
server_instance.send_sync("progress", progress, server_instance.client_id) server_instance.send_sync("progress", progress, server_instance.client_id)
if preview_image is not None: if preview_image is not None:
# Also send old method for backward compatibility
# TODO - Remove after this repo is updated to frontend with metadata support
server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id) server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)
comfy.utils.set_progress_bar_global_hook(hook) comfy.utils.set_progress_bar_global_hook(hook)

server.py

@ -35,11 +35,7 @@ from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager from app.custom_node_manager import CustomNodeManager
from typing import Optional, Union from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes from api_server.routes.internal.internal_routes import InternalRoutes
from protocol import BinaryEventTypes
class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
TEXT = 3
async def send_socket_catch_exception(function, message): async def send_socket_catch_exception(function, message):
try: try:
@ -643,7 +639,8 @@ class PromptServer():
if "prompt" in json_data: if "prompt" in json_data:
prompt = json_data["prompt"] prompt = json_data["prompt"]
valid = execution.validate_prompt(prompt) prompt_id = str(uuid.uuid4())
valid = await execution.validate_prompt(prompt_id, prompt)
extra_data = {} extra_data = {}
if "extra_data" in json_data: if "extra_data" in json_data:
extra_data = json_data["extra_data"] extra_data = json_data["extra_data"]
@ -651,7 +648,6 @@ class PromptServer():
if "client_id" in json_data: if "client_id" in json_data:
extra_data["client_id"] = json_data["client_id"] extra_data["client_id"] = json_data["client_id"]
if valid[0]: if valid[0]:
prompt_id = str(uuid.uuid4())
outputs_to_execute = valid[2] outputs_to_execute = valid[2]
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute)) self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
@ -766,6 +762,10 @@ class PromptServer():
async def send(self, event, data, sid=None): async def send(self, event, data, sid=None):
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
await self.send_image(data, sid=sid) await self.send_image(data, sid=sid)
elif event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA:
# data is (preview_image, metadata)
preview_image, metadata = data
await self.send_image_with_metadata(preview_image, metadata, sid=sid)
elif isinstance(data, (bytes, bytearray)): elif isinstance(data, (bytes, bytearray)):
await self.send_bytes(event, data, sid) await self.send_bytes(event, data, sid)
else: else:
@ -804,6 +804,43 @@ class PromptServer():
preview_bytes = bytesIO.getvalue() preview_bytes = bytesIO.getvalue()
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid) await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
async def send_image_with_metadata(self, image_data, metadata=None, sid=None):
image_type = image_data[0]
image = image_data[1]
max_size = image_data[2]
if max_size is not None:
if hasattr(Image, 'Resampling'):
resampling = Image.Resampling.BILINEAR
else:
resampling = Image.ANTIALIAS  # Pillow < 9 fallback; Image.Resampling doesn't exist here
image = ImageOps.contain(image, (max_size, max_size), resampling)
mimetype = "image/png" if image_type == "PNG" else "image/jpeg"
# Prepare metadata
if metadata is None:
metadata = {}
metadata["image_type"] = mimetype
# Serialize metadata as JSON
import json
metadata_json = json.dumps(metadata).encode('utf-8')
metadata_length = len(metadata_json)
# Prepare image data
bytesIO = BytesIO()
image.save(bytesIO, format=image_type, quality=95, compress_level=1)
image_bytes = bytesIO.getvalue()
# Combine metadata and image
combined_data = bytearray()
combined_data.extend(struct.pack(">I", metadata_length))
combined_data.extend(metadata_json)
combined_data.extend(image_bytes)
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, combined_data, sid=sid)
async def send_bytes(self, event, data, sid=None): async def send_bytes(self, event, data, sid=None):
message = self.encode_bytes(event, data) message = self.encode_bytes(event, data)
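
On the wire, `send_image_with_metadata` packs a 4-byte big-endian metadata length, the JSON metadata, then the encoded image bytes. A client-side decode sketch for the payload (the bytes after the usual event-type header):

import json
import struct

def decode_preview_with_metadata(payload: bytes):
    (metadata_length,) = struct.unpack_from(">I", payload, 0)
    metadata = json.loads(payload[4:4 + metadata_length])
    image_bytes = payload[4 + metadata_length:]
    return metadata, image_bytes  # metadata["image_type"] holds the mimetype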

tests/inference/extra_model_paths.yaml

@ -1,4 +1,4 @@
# Config for testing nodes # Config for testing nodes
testing: testing:
custom_nodes: tests/inference/testing_nodes custom_nodes: testing_nodes

tests/inference/test_execution.py

@ -252,7 +252,7 @@ class TestExecution:
@pytest.mark.parametrize("test_type, test_value", [ @pytest.mark.parametrize("test_type, test_value", [
("StubInt", 5), ("StubInt", 5),
("StubFloat", 5.0) ("StubMask", 5.0)
]) ])
def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder): def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
@ -497,6 +497,69 @@ class TestExecution:
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25" assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
assert not result.did_run(test_node), "The execution should have been cached" assert not result.did_run(test_node), "The execution should have been cached"
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# Create sleep nodes for each duration
sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.8)
sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=2.9)
sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0)
# Add outputs to verify the execution
output1 = g.node("PreviewImage", images=sleep_node1.out(0))
output2 = g.node("PreviewImage", images=sleep_node2.out(0))
output3 = g.node("PreviewImage", images=sleep_node3.out(0))
start_time = time.time()
result = client.run(g)
elapsed_time = time.time() - start_time
# The test should take around 3.0 seconds (the longest sleep duration)
# plus some overhead, but definitely less than the sum of all sleeps (8.7s)
# We'll allow for up to 4.0s total to account for overhead
assert elapsed_time < 4.0, f"Parallel execution took {elapsed_time}s, expected less than 4.0s"
# Verify that all nodes executed
assert result.did_run(sleep_node1), "Sleep node 1 should have run"
assert result.did_run(sleep_node2), "Sleep node 2 should have run"
assert result.did_run(sleep_node3), "Sleep node 3 should have run"
def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
g = builder
# Create input images with different values
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
image3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
# Create a TestParallelSleep node that expands into multiple TestSleep nodes
parallel_sleep = g.node("TestParallelSleep",
image1=image1.out(0),
image2=image2.out(0),
image3=image3.out(0),
sleep1=0.4,
sleep2=0.5,
sleep3=0.6)
output = g.node("SaveImage", images=parallel_sleep.out(0))
start_time = time.time()
result = client.run(g)
elapsed_time = time.time() - start_time
# Similar to the previous test, expect parallel execution of the sleep nodes
# which should complete in less than the sum of all sleeps
assert elapsed_time < 0.8, f"Expansion execution took {elapsed_time}s, expected less than 0.8s"
# Verify the parallel sleep node executed
assert result.did_run(parallel_sleep), "ParallelSleep node should have run"
# Verify we get an image as output (blend of the three input images)
result_images = result.get_images(output)
assert len(result_images) == 1, "Should have 1 image"
# Average pixel value should be around 170 (255 * 2 // 3)
avg_value = numpy.array(result_images[0]).mean()
assert avg_value == 170, f"Image average value {avg_value} should be 170"
# This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker # This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
# as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node, # as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
# only that one entry in the list is blocked. # only that one entry in the list is blocked.

tests/inference/testing_nodes/testing-pack/specific_tests.py

@ -1,6 +1,11 @@
import torch import torch
import time
import asyncio
from comfy.utils import ProgressBar
from .tools import VariantSupport from .tools import VariantSupport
from comfy_execution.graph_utils import GraphBuilder from comfy_execution.graph_utils import GraphBuilder
from comfy.comfy_types.node_typing import ComfyNodeABC
from comfy.comfy_types import IO
class TestLazyMixImages: class TestLazyMixImages:
@classmethod @classmethod
@ -333,6 +338,131 @@ class TestMixedExpansionReturns:
"expand": g.finalize(), "expand": g.finalize(),
} }
class TestSamplingInExpansion:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("MODEL",),
"clip": ("CLIP",),
"vae": ("VAE",),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"steps": ("INT", {"default": 20, "min": 1, "max": 100}),
"cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 30.0}),
"prompt": ("STRING", {"multiline": True, "default": "a beautiful landscape with mountains and trees"}),
"negative_prompt": ("STRING", {"multiline": True, "default": "blurry, bad quality, worst quality"}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "sampling_in_expansion"
CATEGORY = "Testing/Nodes"
def sampling_in_expansion(self, model, clip, vae, seed, steps, cfg, prompt, negative_prompt):
g = GraphBuilder()
# Create a basic image generation workflow using the input model, clip and vae
# 1. Setup text prompts using the provided CLIP model
positive_prompt = g.node("CLIPTextEncode",
text=prompt,
clip=clip)
negative_prompt = g.node("CLIPTextEncode",
text=negative_prompt,
clip=clip)
# 2. Create empty latent with specified size
empty_latent = g.node("EmptyLatentImage", width=512, height=512, batch_size=1)
# 3. Setup sampler and generate image latent
sampler = g.node("KSampler",
model=model,
positive=positive_prompt.out(0),
negative=negative_prompt.out(0),
latent_image=empty_latent.out(0),
seed=seed,
steps=steps,
cfg=cfg,
sampler_name="euler_ancestral",
scheduler="normal")
# 4. Decode latent to image using VAE
output = g.node("VAEDecode", samples=sampler.out(0), vae=vae)
return {
"result": (output.out(0),),
"expand": g.finalize(),
}
class TestSleep(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": (IO.ANY, {}),
"seconds": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 9999.0, "step": 0.01, "tooltip": "The amount of seconds to sleep."}),
},
"hidden": {
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "sleep"
CATEGORY = "_for_testing"
async def sleep(self, value, seconds, unique_id):
pbar = ProgressBar(seconds, node_id=unique_id)
start = time.time()
expiration = start + seconds
now = start
while now < expiration:
now = time.time()
pbar.update_absolute(now - start)
await asyncio.sleep(0.01)
return (value,)
class TestParallelSleep(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image1": ("IMAGE", ),
"image2": ("IMAGE", ),
"image3": ("IMAGE", ),
"sleep1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
"sleep2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
"sleep3": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
},
"hidden": {
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "parallel_sleep"
CATEGORY = "_for_testing"
OUTPUT_NODE = True
def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id):
# Create a graph dynamically with three TestSleep nodes
g = GraphBuilder()
# Create sleep nodes for each duration and image
sleep_node1 = g.node("TestSleep", value=image1, seconds=sleep1)
sleep_node2 = g.node("TestSleep", value=image2, seconds=sleep2)
sleep_node3 = g.node("TestSleep", value=image3, seconds=sleep3)
# Blend the results using TestVariadicAverage
blend = g.node("TestVariadicAverage",
input1=sleep_node1.out(0),
input2=sleep_node2.out(0),
input3=sleep_node3.out(0))
return {
"result": (blend.out(0),),
"expand": g.finalize(),
}
TEST_NODE_CLASS_MAPPINGS = { TEST_NODE_CLASS_MAPPINGS = {
"TestLazyMixImages": TestLazyMixImages, "TestLazyMixImages": TestLazyMixImages,
"TestVariadicAverage": TestVariadicAverage, "TestVariadicAverage": TestVariadicAverage,
@ -345,6 +475,9 @@ TEST_NODE_CLASS_MAPPINGS = {
"TestCustomValidation5": TestCustomValidation5, "TestCustomValidation5": TestCustomValidation5,
"TestDynamicDependencyCycle": TestDynamicDependencyCycle, "TestDynamicDependencyCycle": TestDynamicDependencyCycle,
"TestMixedExpansionReturns": TestMixedExpansionReturns, "TestMixedExpansionReturns": TestMixedExpansionReturns,
"TestSamplingInExpansion": TestSamplingInExpansion,
"TestSleep": TestSleep,
"TestParallelSleep": TestParallelSleep,
} }
TEST_NODE_DISPLAY_NAME_MAPPINGS = { TEST_NODE_DISPLAY_NAME_MAPPINGS = {
@ -359,4 +492,7 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestCustomValidation5": "Custom Validation 5", "TestCustomValidation5": "Custom Validation 5",
"TestDynamicDependencyCycle": "Dynamic Dependency Cycle", "TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
"TestMixedExpansionReturns": "Mixed Expansion Returns", "TestMixedExpansionReturns": "Mixed Expansion Returns",
"TestSamplingInExpansion": "Sampling In Expansion",
"TestSleep": "Test Sleep",
"TestParallelSleep": "Test Parallel Sleep",
} }