mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-27 16:26:39 +00:00
Support for async execution functions
This commit adds support for node execution functions defined as async. When a node's execution function is defined as async, we can continue executing other nodes while it is processing. Standard uses of `await` should "just work", but people will still have to be careful if they spawn actual threads. Because torch doesn't really have async/await versions of functions, this won't particularly help with most locally-executing nodes, but it does work for e.g. web requests to other machines. In addition to the execute function, the `VALIDATE_INPUTS` and `check_lazy_status` functions can also be defined as async, though we'll only resolve one node at a time right now for those.
This commit is contained in:
parent
772de7c006
commit
46c8311d14
@ -997,11 +997,12 @@ def set_progress_bar_global_hook(function):
|
|||||||
PROGRESS_BAR_HOOK = function
|
PROGRESS_BAR_HOOK = function
|
||||||
|
|
||||||
class ProgressBar:
|
class ProgressBar:
|
||||||
def __init__(self, total):
|
def __init__(self, total, node_id=None):
|
||||||
global PROGRESS_BAR_HOOK
|
global PROGRESS_BAR_HOOK
|
||||||
self.total = total
|
self.total = total
|
||||||
self.current = 0
|
self.current = 0
|
||||||
self.hook = PROGRESS_BAR_HOOK
|
self.hook = PROGRESS_BAR_HOOK
|
||||||
|
self.node_id = node_id
|
||||||
|
|
||||||
def update_absolute(self, value, total=None, preview=None):
|
def update_absolute(self, value, total=None, preview=None):
|
||||||
if total is not None:
|
if total is not None:
|
||||||
@ -1010,7 +1011,7 @@ class ProgressBar:
|
|||||||
value = self.total
|
value = self.total
|
||||||
self.current = value
|
self.current = value
|
||||||
if self.hook is not None:
|
if self.hook is not None:
|
||||||
self.hook(self.current, self.total, preview)
|
self.hook(self.current, self.total, preview, node_id=self.node_id)
|
||||||
|
|
||||||
def update(self, value):
|
def update(self, value):
|
||||||
self.update_absolute(self.current + value)
|
self.update_absolute(self.current + value)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import itertools
|
import itertools
|
||||||
from typing import Sequence, Mapping, Dict
|
from typing import Sequence, Mapping, Dict
|
||||||
from comfy_execution.graph import DynamicPrompt
|
from comfy_execution.graph import DynamicPrompt
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
import nodes
|
import nodes
|
||||||
|
|
||||||
@ -16,12 +17,13 @@ def include_unique_id_in_input(class_type: str) -> bool:
|
|||||||
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
|
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
|
||||||
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||||
|
|
||||||
class CacheKeySet:
|
class CacheKeySet(ABC):
|
||||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||||
self.keys = {}
|
self.keys = {}
|
||||||
self.subcache_keys = {}
|
self.subcache_keys = {}
|
||||||
|
|
||||||
def add_keys(self, node_ids):
|
@abstractmethod
|
||||||
|
async def add_keys(self, node_ids):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def all_node_ids(self):
|
def all_node_ids(self):
|
||||||
@ -60,9 +62,8 @@ class CacheKeySetID(CacheKeySet):
|
|||||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||||
self.dynprompt = dynprompt
|
self.dynprompt = dynprompt
|
||||||
self.add_keys(node_ids)
|
|
||||||
|
|
||||||
def add_keys(self, node_ids):
|
async def add_keys(self, node_ids):
|
||||||
for node_id in node_ids:
|
for node_id in node_ids:
|
||||||
if node_id in self.keys:
|
if node_id in self.keys:
|
||||||
continue
|
continue
|
||||||
@ -77,37 +78,36 @@ class CacheKeySetInputSignature(CacheKeySet):
|
|||||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||||
self.dynprompt = dynprompt
|
self.dynprompt = dynprompt
|
||||||
self.is_changed_cache = is_changed_cache
|
self.is_changed_cache = is_changed_cache
|
||||||
self.add_keys(node_ids)
|
|
||||||
|
|
||||||
def include_node_id_in_input(self) -> bool:
|
def include_node_id_in_input(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def add_keys(self, node_ids):
|
async def add_keys(self, node_ids):
|
||||||
for node_id in node_ids:
|
for node_id in node_ids:
|
||||||
if node_id in self.keys:
|
if node_id in self.keys:
|
||||||
continue
|
continue
|
||||||
if not self.dynprompt.has_node(node_id):
|
if not self.dynprompt.has_node(node_id):
|
||||||
continue
|
continue
|
||||||
node = self.dynprompt.get_node(node_id)
|
node = self.dynprompt.get_node(node_id)
|
||||||
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
|
self.keys[node_id] = await self.get_node_signature(self.dynprompt, node_id)
|
||||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||||
|
|
||||||
def get_node_signature(self, dynprompt, node_id):
|
async def get_node_signature(self, dynprompt, node_id):
|
||||||
signature = []
|
signature = []
|
||||||
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
|
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
|
||||||
signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
||||||
for ancestor_id in ancestors:
|
for ancestor_id in ancestors:
|
||||||
signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
||||||
return to_hashable(signature)
|
return to_hashable(signature)
|
||||||
|
|
||||||
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||||
if not dynprompt.has_node(node_id):
|
if not dynprompt.has_node(node_id):
|
||||||
# This node doesn't exist -- we can't cache it.
|
# This node doesn't exist -- we can't cache it.
|
||||||
return [float("NaN")]
|
return [float("NaN")]
|
||||||
node = dynprompt.get_node(node_id)
|
node = dynprompt.get_node(node_id)
|
||||||
class_type = node["class_type"]
|
class_type = node["class_type"]
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
signature = [class_type, self.is_changed_cache.get(node_id)]
|
signature = [class_type, await self.is_changed_cache.get(node_id)]
|
||||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
||||||
signature.append(node_id)
|
signature.append(node_id)
|
||||||
inputs = node["inputs"]
|
inputs = node["inputs"]
|
||||||
@ -150,9 +150,10 @@ class BasicCache:
|
|||||||
self.cache = {}
|
self.cache = {}
|
||||||
self.subcaches = {}
|
self.subcaches = {}
|
||||||
|
|
||||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||||
self.dynprompt = dynprompt
|
self.dynprompt = dynprompt
|
||||||
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
|
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
|
||||||
|
await self.cache_key_set.add_keys(node_ids)
|
||||||
self.is_changed_cache = is_changed_cache
|
self.is_changed_cache = is_changed_cache
|
||||||
self.initialized = True
|
self.initialized = True
|
||||||
|
|
||||||
@ -201,13 +202,13 @@ class BasicCache:
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _ensure_subcache(self, node_id, children_ids):
|
async def _ensure_subcache(self, node_id, children_ids):
|
||||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||||
subcache = self.subcaches.get(subcache_key, None)
|
subcache = self.subcaches.get(subcache_key, None)
|
||||||
if subcache is None:
|
if subcache is None:
|
||||||
subcache = BasicCache(self.key_class)
|
subcache = BasicCache(self.key_class)
|
||||||
self.subcaches[subcache_key] = subcache
|
self.subcaches[subcache_key] = subcache
|
||||||
subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
||||||
return subcache
|
return subcache
|
||||||
|
|
||||||
def _get_subcache(self, node_id):
|
def _get_subcache(self, node_id):
|
||||||
@ -259,10 +260,10 @@ class HierarchicalCache(BasicCache):
|
|||||||
assert cache is not None
|
assert cache is not None
|
||||||
cache._set_immediate(node_id, value)
|
cache._set_immediate(node_id, value)
|
||||||
|
|
||||||
def ensure_subcache_for(self, node_id, children_ids):
|
async def ensure_subcache_for(self, node_id, children_ids):
|
||||||
cache = self._get_cache_for(node_id)
|
cache = self._get_cache_for(node_id)
|
||||||
assert cache is not None
|
assert cache is not None
|
||||||
return cache._ensure_subcache(node_id, children_ids)
|
return await cache._ensure_subcache(node_id, children_ids)
|
||||||
|
|
||||||
class LRUCache(BasicCache):
|
class LRUCache(BasicCache):
|
||||||
def __init__(self, key_class, max_size=100):
|
def __init__(self, key_class, max_size=100):
|
||||||
@ -273,8 +274,8 @@ class LRUCache(BasicCache):
|
|||||||
self.used_generation = {}
|
self.used_generation = {}
|
||||||
self.children = {}
|
self.children = {}
|
||||||
|
|
||||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||||
super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||||
self.generation += 1
|
self.generation += 1
|
||||||
for node_id in node_ids:
|
for node_id in node_ids:
|
||||||
self._mark_used(node_id)
|
self._mark_used(node_id)
|
||||||
@ -303,11 +304,11 @@ class LRUCache(BasicCache):
|
|||||||
self._mark_used(node_id)
|
self._mark_used(node_id)
|
||||||
return self._set_immediate(node_id, value)
|
return self._set_immediate(node_id, value)
|
||||||
|
|
||||||
def ensure_subcache_for(self, node_id, children_ids):
|
async def ensure_subcache_for(self, node_id, children_ids):
|
||||||
# Just uses subcaches for tracking 'live' nodes
|
# Just uses subcaches for tracking 'live' nodes
|
||||||
super()._ensure_subcache(node_id, children_ids)
|
await super()._ensure_subcache(node_id, children_ids)
|
||||||
|
|
||||||
self.cache_key_set.add_keys(children_ids)
|
await self.cache_key_set.add_keys(children_ids)
|
||||||
self._mark_used(node_id)
|
self._mark_used(node_id)
|
||||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
self.children[cache_key] = []
|
self.children[cache_key] = []
|
||||||
@ -337,7 +338,7 @@ class DependencyAwareCache(BasicCache):
|
|||||||
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
|
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
|
||||||
self.executed_nodes = set() # Tracks nodes that have been executed
|
self.executed_nodes = set() # Tracks nodes that have been executed
|
||||||
|
|
||||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||||
"""
|
"""
|
||||||
Clear the entire cache and rebuild the dependency graph.
|
Clear the entire cache and rebuild the dependency graph.
|
||||||
|
|
||||||
@ -354,7 +355,7 @@ class DependencyAwareCache(BasicCache):
|
|||||||
self.executed_nodes.clear()
|
self.executed_nodes.clear()
|
||||||
|
|
||||||
# Call the parent method to initialize the cache with the new prompt
|
# Call the parent method to initialize the cache with the new prompt
|
||||||
super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||||
|
|
||||||
# Rebuild the dependency graph
|
# Rebuild the dependency graph
|
||||||
self._build_dependency_graph(dynprompt, node_ids)
|
self._build_dependency_graph(dynprompt, node_ids)
|
||||||
@ -405,7 +406,7 @@ class DependencyAwareCache(BasicCache):
|
|||||||
"""
|
"""
|
||||||
return self._get_immediate(node_id)
|
return self._get_immediate(node_id)
|
||||||
|
|
||||||
def ensure_subcache_for(self, node_id, children_ids):
|
async def ensure_subcache_for(self, node_id, children_ids):
|
||||||
"""
|
"""
|
||||||
Ensure a subcache exists for a node and update dependencies.
|
Ensure a subcache exists for a node and update dependencies.
|
||||||
|
|
||||||
@ -416,7 +417,7 @@ class DependencyAwareCache(BasicCache):
|
|||||||
Returns:
|
Returns:
|
||||||
The subcache object for the node.
|
The subcache object for the node.
|
||||||
"""
|
"""
|
||||||
subcache = super()._ensure_subcache(node_id, children_ids)
|
subcache = await super()._ensure_subcache(node_id, children_ids)
|
||||||
for child_id in children_ids:
|
for child_id in children_ids:
|
||||||
self.descendants[node_id].add(child_id)
|
self.descendants[node_id].add(child_id)
|
||||||
self.ancestors[child_id].add(node_id)
|
self.ancestors[child_id].add(node_id)
|
||||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
from typing import Type, Literal
|
from typing import Type, Literal
|
||||||
|
|
||||||
import nodes
|
import nodes
|
||||||
|
import asyncio
|
||||||
from comfy_execution.graph_utils import is_link
|
from comfy_execution.graph_utils import is_link
|
||||||
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
|
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
|
||||||
|
|
||||||
@ -100,6 +101,8 @@ class TopologicalSort:
|
|||||||
self.pendingNodes = {}
|
self.pendingNodes = {}
|
||||||
self.blockCount = {} # Number of nodes this node is directly blocked by
|
self.blockCount = {} # Number of nodes this node is directly blocked by
|
||||||
self.blocking = {} # Which nodes are blocked by this node
|
self.blocking = {} # Which nodes are blocked by this node
|
||||||
|
self.externalBlocks = 0
|
||||||
|
self.unblockedEvent = asyncio.Event()
|
||||||
|
|
||||||
def get_input_info(self, unique_id, input_name):
|
def get_input_info(self, unique_id, input_name):
|
||||||
class_type = self.dynprompt.get_node(unique_id)["class_type"]
|
class_type = self.dynprompt.get_node(unique_id)["class_type"]
|
||||||
@ -153,6 +156,16 @@ class TopologicalSort:
|
|||||||
for link in links:
|
for link in links:
|
||||||
self.add_strong_link(*link)
|
self.add_strong_link(*link)
|
||||||
|
|
||||||
|
def add_external_block(self, node_id):
|
||||||
|
assert node_id in self.blockCount, "Can't add external block to a node that isn't pending"
|
||||||
|
self.externalBlocks += 1
|
||||||
|
self.blockCount[node_id] += 1
|
||||||
|
def unblock():
|
||||||
|
self.externalBlocks -= 1
|
||||||
|
self.blockCount[node_id] -= 1
|
||||||
|
self.unblockedEvent.set()
|
||||||
|
return unblock
|
||||||
|
|
||||||
def is_cached(self, node_id):
|
def is_cached(self, node_id):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -181,11 +194,16 @@ class ExecutionList(TopologicalSort):
|
|||||||
def is_cached(self, node_id):
|
def is_cached(self, node_id):
|
||||||
return self.output_cache.get(node_id) is not None
|
return self.output_cache.get(node_id) is not None
|
||||||
|
|
||||||
def stage_node_execution(self):
|
async def stage_node_execution(self):
|
||||||
assert self.staged_node_id is None
|
assert self.staged_node_id is None
|
||||||
if self.is_empty():
|
if self.is_empty():
|
||||||
return None, None, None
|
return None, None, None
|
||||||
available = self.get_ready_nodes()
|
available = self.get_ready_nodes()
|
||||||
|
while len(available) == 0 and self.externalBlocks > 0:
|
||||||
|
# Wait for an external block to be released
|
||||||
|
await self.unblockedEvent.wait()
|
||||||
|
self.unblockedEvent.clear()
|
||||||
|
available = self.get_ready_nodes()
|
||||||
if len(available) == 0:
|
if len(available) == 0:
|
||||||
cycled_nodes = self.get_nodes_in_cycle()
|
cycled_nodes = self.get_nodes_in_cycle()
|
||||||
# Because cycles composed entirely of static nodes are caught during initial validation,
|
# Because cycles composed entirely of static nodes are caught during initial validation,
|
||||||
|
288
comfy_execution/progress.py
Normal file
288
comfy_execution/progress.py
Normal file
@ -0,0 +1,288 @@
|
|||||||
|
from typing import TypedDict, Dict, Optional
|
||||||
|
from typing_extensions import override
|
||||||
|
from PIL import Image
|
||||||
|
from enum import Enum
|
||||||
|
from abc import ABC
|
||||||
|
from tqdm import tqdm
|
||||||
|
from comfy_execution.graph import DynamicPrompt
|
||||||
|
from protocol import BinaryEventTypes
|
||||||
|
|
||||||
|
class NodeState(Enum):
|
||||||
|
Pending = "pending"
|
||||||
|
Running = "running"
|
||||||
|
Finished = "finished"
|
||||||
|
Error = "error"
|
||||||
|
|
||||||
|
class NodeProgressState(TypedDict):
|
||||||
|
"""
|
||||||
|
A class to represent the state of a node's progress.
|
||||||
|
"""
|
||||||
|
state: NodeState
|
||||||
|
value: float
|
||||||
|
max: float
|
||||||
|
|
||||||
|
class ProgressHandler(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for progress handlers.
|
||||||
|
Progress handlers receive progress updates and display them in various ways.
|
||||||
|
"""
|
||||||
|
def __init__(self, name: str):
|
||||||
|
self.name = name
|
||||||
|
self.enabled = True
|
||||||
|
|
||||||
|
def set_registry(self, registry: "ProgressRegistry"):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||||
|
"""Called when a node starts processing"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def update_handler(self, node_id: str, value: float, max_value: float,
|
||||||
|
state: NodeProgressState, prompt_id: str, image: Optional[Image.Image] = None):
|
||||||
|
"""Called when a node's progress is updated"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||||
|
"""Called when a node finishes processing"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Called when the progress registry is reset"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def enable(self):
|
||||||
|
"""Enable this handler"""
|
||||||
|
self.enabled = True
|
||||||
|
|
||||||
|
def disable(self):
|
||||||
|
"""Disable this handler"""
|
||||||
|
self.enabled = False
|
||||||
|
|
||||||
|
class CLIProgressHandler(ProgressHandler):
|
||||||
|
"""
|
||||||
|
Handler that displays progress using tqdm progress bars in the CLI.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("cli")
|
||||||
|
self.progress_bars: Dict[str, tqdm] = {}
|
||||||
|
|
||||||
|
@override
|
||||||
|
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||||
|
# Create a new tqdm progress bar
|
||||||
|
if node_id not in self.progress_bars:
|
||||||
|
self.progress_bars[node_id] = tqdm(
|
||||||
|
total=state["max"],
|
||||||
|
desc=f"Node {node_id}",
|
||||||
|
unit="steps",
|
||||||
|
leave=True,
|
||||||
|
position=len(self.progress_bars)
|
||||||
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def update_handler(self, node_id: str, value: float, max_value: float,
|
||||||
|
state: NodeProgressState, prompt_id: str, image: Optional[Image.Image] = None):
|
||||||
|
# Handle case where start_handler wasn't called
|
||||||
|
if node_id not in self.progress_bars:
|
||||||
|
self.progress_bars[node_id] = tqdm(
|
||||||
|
total=max_value,
|
||||||
|
desc=f"Node {node_id}",
|
||||||
|
unit="steps",
|
||||||
|
leave=True,
|
||||||
|
position=len(self.progress_bars)
|
||||||
|
)
|
||||||
|
self.progress_bars[node_id].update(value)
|
||||||
|
else:
|
||||||
|
# Update existing progress bar
|
||||||
|
if max_value != self.progress_bars[node_id].total:
|
||||||
|
self.progress_bars[node_id].total = max_value
|
||||||
|
# Calculate the update amount (difference from current position)
|
||||||
|
current_position = self.progress_bars[node_id].n
|
||||||
|
update_amount = value - current_position
|
||||||
|
if update_amount > 0:
|
||||||
|
self.progress_bars[node_id].update(update_amount)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||||
|
# Complete and close the progress bar if it exists
|
||||||
|
if node_id in self.progress_bars:
|
||||||
|
# Ensure the bar shows 100% completion
|
||||||
|
remaining = state["max"] - self.progress_bars[node_id].n
|
||||||
|
if remaining > 0:
|
||||||
|
self.progress_bars[node_id].update(remaining)
|
||||||
|
self.progress_bars[node_id].close()
|
||||||
|
del self.progress_bars[node_id]
|
||||||
|
|
||||||
|
@override
|
||||||
|
def reset(self):
|
||||||
|
# Close all progress bars
|
||||||
|
for bar in self.progress_bars.values():
|
||||||
|
bar.close()
|
||||||
|
self.progress_bars.clear()
|
||||||
|
|
||||||
|
class WebUIProgressHandler(ProgressHandler):
|
||||||
|
"""
|
||||||
|
Handler that sends progress updates to the WebUI via WebSockets.
|
||||||
|
"""
|
||||||
|
def __init__(self, server_instance):
|
||||||
|
super().__init__("webui")
|
||||||
|
self.server_instance = server_instance
|
||||||
|
|
||||||
|
def set_registry(self, registry: "ProgressRegistry"):
|
||||||
|
self.registry = registry
|
||||||
|
|
||||||
|
def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]):
|
||||||
|
"""Send the current progress state to the client"""
|
||||||
|
if self.server_instance is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Only send info for non-pending nodes
|
||||||
|
active_nodes = {
|
||||||
|
node_id: {
|
||||||
|
"value": state["value"],
|
||||||
|
"max": state["max"],
|
||||||
|
"state": state["state"].value,
|
||||||
|
"node_id": node_id,
|
||||||
|
"prompt_id": prompt_id,
|
||||||
|
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
|
||||||
|
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
|
||||||
|
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id)
|
||||||
|
}
|
||||||
|
for node_id, state in nodes.items()
|
||||||
|
if state["state"] != NodeState.Pending
|
||||||
|
}
|
||||||
|
|
||||||
|
# Send a combined progress_state message with all node states
|
||||||
|
self.server_instance.send_sync("progress_state", {
|
||||||
|
"prompt_id": prompt_id,
|
||||||
|
"nodes": active_nodes
|
||||||
|
})
|
||||||
|
|
||||||
|
@override
|
||||||
|
def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||||
|
# Send progress state of all nodes
|
||||||
|
if self.registry:
|
||||||
|
self._send_progress_state(prompt_id, self.registry.nodes)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def update_handler(self, node_id: str, value: float, max_value: float,
|
||||||
|
state: NodeProgressState, prompt_id: str, image: Optional[Image.Image] = None):
|
||||||
|
# Send progress state of all nodes
|
||||||
|
if self.registry:
|
||||||
|
self._send_progress_state(prompt_id, self.registry.nodes)
|
||||||
|
if image:
|
||||||
|
metadata = {
|
||||||
|
"node_id": node_id,
|
||||||
|
"prompt_id": prompt_id,
|
||||||
|
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
|
||||||
|
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
|
||||||
|
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id)
|
||||||
|
}
|
||||||
|
self.server_instance.send_sync(BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, (image, metadata), self.server_instance.client_id)
|
||||||
|
|
||||||
|
|
||||||
|
@override
|
||||||
|
def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
|
||||||
|
# Send progress state of all nodes
|
||||||
|
if self.registry:
|
||||||
|
self._send_progress_state(prompt_id, self.registry.nodes)
|
||||||
|
|
||||||
|
class ProgressRegistry:
|
||||||
|
"""
|
||||||
|
Registry that maintains node progress state and notifies registered handlers.
|
||||||
|
"""
|
||||||
|
def __init__(self, prompt_id: str, dynprompt: DynamicPrompt):
|
||||||
|
self.prompt_id = prompt_id
|
||||||
|
self.dynprompt = dynprompt
|
||||||
|
self.nodes: Dict[str, NodeProgressState] = {}
|
||||||
|
self.handlers: Dict[str, ProgressHandler] = {}
|
||||||
|
|
||||||
|
def register_handler(self, handler: ProgressHandler) -> None:
|
||||||
|
"""Register a progress handler"""
|
||||||
|
self.handlers[handler.name] = handler
|
||||||
|
|
||||||
|
def unregister_handler(self, handler_name: str) -> None:
|
||||||
|
"""Unregister a progress handler"""
|
||||||
|
if handler_name in self.handlers:
|
||||||
|
# Allow handler to clean up resources
|
||||||
|
self.handlers[handler_name].reset()
|
||||||
|
del self.handlers[handler_name]
|
||||||
|
|
||||||
|
def enable_handler(self, handler_name: str) -> None:
|
||||||
|
"""Enable a progress handler"""
|
||||||
|
if handler_name in self.handlers:
|
||||||
|
self.handlers[handler_name].enable()
|
||||||
|
|
||||||
|
def disable_handler(self, handler_name: str) -> None:
|
||||||
|
"""Disable a progress handler"""
|
||||||
|
if handler_name in self.handlers:
|
||||||
|
self.handlers[handler_name].disable()
|
||||||
|
|
||||||
|
def ensure_entry(self, node_id: str) -> NodeProgressState:
|
||||||
|
"""Ensure a node entry exists"""
|
||||||
|
if node_id not in self.nodes:
|
||||||
|
self.nodes[node_id] = NodeProgressState(
|
||||||
|
state = NodeState.Pending,
|
||||||
|
value = 0,
|
||||||
|
max = 1
|
||||||
|
)
|
||||||
|
return self.nodes[node_id]
|
||||||
|
|
||||||
|
def start_progress(self, node_id: str) -> None:
|
||||||
|
"""Start progress tracking for a node"""
|
||||||
|
entry = self.ensure_entry(node_id)
|
||||||
|
entry["state"] = NodeState.Running
|
||||||
|
entry["value"] = 0.0
|
||||||
|
entry["max"] = 1.0
|
||||||
|
|
||||||
|
# Notify all enabled handlers
|
||||||
|
for handler in self.handlers.values():
|
||||||
|
if handler.enabled:
|
||||||
|
handler.start_handler(node_id, entry, self.prompt_id)
|
||||||
|
|
||||||
|
def update_progress(self, node_id: str, value: float, max_value: float, image: Optional[Image.Image]) -> None:
|
||||||
|
"""Update progress for a node"""
|
||||||
|
entry = self.ensure_entry(node_id)
|
||||||
|
entry["state"] = NodeState.Running
|
||||||
|
entry["value"] = value
|
||||||
|
entry["max"] = max_value
|
||||||
|
|
||||||
|
# Notify all enabled handlers
|
||||||
|
for handler in self.handlers.values():
|
||||||
|
if handler.enabled:
|
||||||
|
handler.update_handler(node_id, value, max_value, entry, self.prompt_id, image)
|
||||||
|
|
||||||
|
def finish_progress(self, node_id: str) -> None:
|
||||||
|
"""Finish progress tracking for a node"""
|
||||||
|
entry = self.ensure_entry(node_id)
|
||||||
|
entry["state"] = NodeState.Finished
|
||||||
|
entry["value"] = entry["max"]
|
||||||
|
|
||||||
|
# Notify all enabled handlers
|
||||||
|
for handler in self.handlers.values():
|
||||||
|
if handler.enabled:
|
||||||
|
handler.finish_handler(node_id, entry, self.prompt_id)
|
||||||
|
|
||||||
|
def reset_handlers(self) -> None:
|
||||||
|
"""Reset all handlers"""
|
||||||
|
for handler in self.handlers.values():
|
||||||
|
handler.reset()
|
||||||
|
|
||||||
|
# Global registry instance
|
||||||
|
global_progress_registry: ProgressRegistry = ProgressRegistry(prompt_id="", dynprompt=DynamicPrompt({}))
|
||||||
|
|
||||||
|
def reset_progress_state(prompt_id: str, dynprompt: DynamicPrompt) -> None:
|
||||||
|
global global_progress_registry
|
||||||
|
|
||||||
|
# Reset existing handlers if registry exists
|
||||||
|
if global_progress_registry is not None:
|
||||||
|
global_progress_registry.reset_handlers()
|
||||||
|
|
||||||
|
# Create new registry
|
||||||
|
global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
|
||||||
|
|
||||||
|
def add_progress_handler(handler: ProgressHandler) -> None:
|
||||||
|
handler.set_registry(global_progress_registry)
|
||||||
|
global_progress_registry.register_handler(handler)
|
||||||
|
|
||||||
|
def get_progress_state() -> ProgressRegistry:
|
||||||
|
return global_progress_registry
|
46
comfy_execution/utils.py
Normal file
46
comfy_execution/utils.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import contextvars
|
||||||
|
from typing import Optional, NamedTuple
|
||||||
|
|
||||||
|
class ExecutionContext(NamedTuple):
|
||||||
|
"""
|
||||||
|
Context information about the currently executing node.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
node_id: The ID of the currently executing node
|
||||||
|
list_index: The index in a list being processed (for operations on batches/lists)
|
||||||
|
"""
|
||||||
|
prompt_id: str
|
||||||
|
node_id: str
|
||||||
|
list_index: Optional[int]
|
||||||
|
|
||||||
|
current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None)
|
||||||
|
|
||||||
|
def get_executing_context() -> Optional[ExecutionContext]:
|
||||||
|
return current_executing_context.get(None)
|
||||||
|
|
||||||
|
class CurrentNodeContext:
|
||||||
|
"""
|
||||||
|
Context manager for setting the current executing node context.
|
||||||
|
|
||||||
|
Sets the current_executing_context on enter and resets it on exit.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
with CurrentNodeContext(node_id="123", list_index=0):
|
||||||
|
# Code that should run with the current node context set
|
||||||
|
process_image()
|
||||||
|
"""
|
||||||
|
def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
|
||||||
|
self.context = ExecutionContext(
|
||||||
|
prompt_id= prompt_id,
|
||||||
|
node_id= node_id,
|
||||||
|
list_index= list_index
|
||||||
|
)
|
||||||
|
self.token = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.token = current_executing_context.set(self.context)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
if self.token is not None:
|
||||||
|
current_executing_context.reset(self.token)
|
125
execution.py
125
execution.py
@ -8,12 +8,14 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Literal, NamedTuple, Optional
|
from typing import List, Literal, NamedTuple, Optional
|
||||||
|
import asyncio
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import nodes
|
import nodes
|
||||||
from comfy_execution.caching import (
|
from comfy_execution.caching import (
|
||||||
|
BasicCache,
|
||||||
CacheKeySetID,
|
CacheKeySetID,
|
||||||
CacheKeySetInputSignature,
|
CacheKeySetInputSignature,
|
||||||
DependencyAwareCache,
|
DependencyAwareCache,
|
||||||
@ -28,6 +30,8 @@ from comfy_execution.graph import (
|
|||||||
)
|
)
|
||||||
from comfy_execution.graph_utils import GraphBuilder, is_link
|
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||||
from comfy_execution.validation import validate_node_input
|
from comfy_execution.validation import validate_node_input
|
||||||
|
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
||||||
|
from comfy_execution.utils import CurrentNodeContext
|
||||||
|
|
||||||
|
|
||||||
class ExecutionResult(Enum):
|
class ExecutionResult(Enum):
|
||||||
@ -39,12 +43,13 @@ class DuplicateNodeError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
class IsChangedCache:
|
class IsChangedCache:
|
||||||
def __init__(self, dynprompt, outputs_cache):
|
def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache):
|
||||||
|
self.prompt_id = prompt_id
|
||||||
self.dynprompt = dynprompt
|
self.dynprompt = dynprompt
|
||||||
self.outputs_cache = outputs_cache
|
self.outputs_cache = outputs_cache
|
||||||
self.is_changed = {}
|
self.is_changed = {}
|
||||||
|
|
||||||
def get(self, node_id):
|
async def get(self, node_id):
|
||||||
if node_id in self.is_changed:
|
if node_id in self.is_changed:
|
||||||
return self.is_changed[node_id]
|
return self.is_changed[node_id]
|
||||||
|
|
||||||
@ -62,7 +67,8 @@ class IsChangedCache:
|
|||||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||||
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
|
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
|
||||||
try:
|
try:
|
||||||
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, "IS_CHANGED")
|
||||||
|
is_changed = await resolve_map_node_over_list_results(is_changed)
|
||||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning("WARNING: {}".format(e))
|
logging.warning("WARNING: {}".format(e))
|
||||||
@ -164,7 +170,19 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
|
|||||||
|
|
||||||
map_node_over_list = None #Don't hook this please
|
map_node_over_list = None #Don't hook this please
|
||||||
|
|
||||||
def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
async def resolve_map_node_over_list_results(results):
|
||||||
|
remaining = [x for x in results if isinstance(x, asyncio.Task) and not x.done()]
|
||||||
|
if len(remaining) == 0:
|
||||||
|
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
||||||
|
else:
|
||||||
|
done, pending = await asyncio.wait(remaining)
|
||||||
|
for task in done:
|
||||||
|
exc = task.exception()
|
||||||
|
if exc is not None:
|
||||||
|
raise exc
|
||||||
|
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
||||||
|
|
||||||
|
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||||
# check if node wants the lists
|
# check if node wants the lists
|
||||||
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
||||||
|
|
||||||
@ -178,7 +196,7 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
|
|||||||
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
|
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
def process_inputs(inputs, index=None, input_is_list=False):
|
async def process_inputs(inputs, index=None, input_is_list=False):
|
||||||
if allow_interrupt:
|
if allow_interrupt:
|
||||||
nodes.before_node_execution()
|
nodes.before_node_execution()
|
||||||
execution_block = None
|
execution_block = None
|
||||||
@ -194,20 +212,37 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
|
|||||||
if execution_block is None:
|
if execution_block is None:
|
||||||
if pre_execute_cb is not None and index is not None:
|
if pre_execute_cb is not None and index is not None:
|
||||||
pre_execute_cb(index)
|
pre_execute_cb(index)
|
||||||
results.append(getattr(obj, func)(**inputs))
|
f = getattr(obj, func)
|
||||||
|
if inspect.iscoroutinefunction(f):
|
||||||
|
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
|
||||||
|
with CurrentNodeContext(prompt_id, unique_id, list_index):
|
||||||
|
return await f(**args)
|
||||||
|
task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs))
|
||||||
|
# Give the task a chance to execute without yielding
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
if task.done():
|
||||||
|
result = task.result()
|
||||||
|
results.append(result)
|
||||||
|
else:
|
||||||
|
results.append(task)
|
||||||
|
else:
|
||||||
|
with CurrentNodeContext(prompt_id, unique_id, index):
|
||||||
|
result = f(**inputs)
|
||||||
|
results.append(result)
|
||||||
else:
|
else:
|
||||||
results.append(execution_block)
|
results.append(execution_block)
|
||||||
|
|
||||||
if input_is_list:
|
if input_is_list:
|
||||||
process_inputs(input_data_all, 0, input_is_list=input_is_list)
|
await process_inputs(input_data_all, 0, input_is_list=input_is_list)
|
||||||
elif max_len_input == 0:
|
elif max_len_input == 0:
|
||||||
process_inputs({})
|
await process_inputs({})
|
||||||
else:
|
else:
|
||||||
for i in range(max_len_input):
|
for i in range(max_len_input):
|
||||||
input_dict = slice_dict(input_data_all, i)
|
input_dict = slice_dict(input_data_all, i)
|
||||||
process_inputs(input_dict, i)
|
await process_inputs(input_dict, i)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def merge_result_data(results, obj):
|
def merge_result_data(results, obj):
|
||||||
# check which outputs need concatenating
|
# check which outputs need concatenating
|
||||||
output = []
|
output = []
|
||||||
@ -229,11 +264,18 @@ def merge_result_data(results, obj):
|
|||||||
output.append([o[i] for o in results])
|
output.append([o[i] for o in results])
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
|
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
|
||||||
|
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||||
|
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
|
||||||
|
if has_pending_task:
|
||||||
|
return return_values, {}, False, has_pending_task
|
||||||
|
output, ui, has_subgraph = get_output_from_returns(return_values, obj)
|
||||||
|
return output, ui, has_subgraph, False
|
||||||
|
|
||||||
|
def get_output_from_returns(return_values, obj):
|
||||||
results = []
|
results = []
|
||||||
uis = []
|
uis = []
|
||||||
subgraph_results = []
|
subgraph_results = []
|
||||||
return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
|
||||||
has_subgraph = False
|
has_subgraph = False
|
||||||
for i in range(len(return_values)):
|
for i in range(len(return_values)):
|
||||||
r = return_values[i]
|
r = return_values[i]
|
||||||
@ -267,6 +309,10 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
|
|||||||
else:
|
else:
|
||||||
output = []
|
output = []
|
||||||
ui = dict()
|
ui = dict()
|
||||||
|
# TODO: Think there's an existing bug here
|
||||||
|
# If we're performing a subgraph expansion, we probably shouldn't be returning UI values yet.
|
||||||
|
# They'll get cached without the completed subgraphs. It's an edge case and I'm not aware of
|
||||||
|
# any nodes that use both subgraph expansion and custom UI outputs, but might be a problem in the future.
|
||||||
if len(uis) > 0:
|
if len(uis) > 0:
|
||||||
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||||
return output, ui, has_subgraph
|
return output, ui, has_subgraph
|
||||||
@ -279,7 +325,7 @@ def format_value(x):
|
|||||||
else:
|
else:
|
||||||
return str(x)
|
return str(x)
|
||||||
|
|
||||||
def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
|
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
|
||||||
unique_id = current_item
|
unique_id = current_item
|
||||||
real_node_id = dynprompt.get_real_node_id(unique_id)
|
real_node_id = dynprompt.get_real_node_id(unique_id)
|
||||||
display_node_id = dynprompt.get_display_node_id(unique_id)
|
display_node_id = dynprompt.get_display_node_id(unique_id)
|
||||||
@ -291,11 +337,16 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
|||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
cached_output = caches.ui.get(unique_id) or {}
|
cached_output = caches.ui.get(unique_id) or {}
|
||||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
|
||||||
|
get_progress_state().finish_progress(unique_id)
|
||||||
return (ExecutionResult.SUCCESS, None, None)
|
return (ExecutionResult.SUCCESS, None, None)
|
||||||
|
|
||||||
input_data_all = None
|
input_data_all = None
|
||||||
try:
|
try:
|
||||||
if unique_id in pending_subgraph_results:
|
if unique_id in pending_async_nodes:
|
||||||
|
results = [r.result() if isinstance(r, asyncio.Task) else r for r in pending_async_nodes[unique_id]]
|
||||||
|
del pending_async_nodes[unique_id]
|
||||||
|
output_data, output_ui, has_subgraph = get_output_from_returns(results, class_def)
|
||||||
|
elif unique_id in pending_subgraph_results:
|
||||||
cached_results = pending_subgraph_results[unique_id]
|
cached_results = pending_subgraph_results[unique_id]
|
||||||
resolved_outputs = []
|
resolved_outputs = []
|
||||||
for is_subgraph, result in cached_results:
|
for is_subgraph, result in cached_results:
|
||||||
@ -317,6 +368,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
|||||||
output_ui = []
|
output_ui = []
|
||||||
has_subgraph = False
|
has_subgraph = False
|
||||||
else:
|
else:
|
||||||
|
get_progress_state().start_progress(unique_id)
|
||||||
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.last_node_id = display_node_id
|
server.last_node_id = display_node_id
|
||||||
@ -328,7 +380,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
|||||||
caches.objects.set(unique_id, obj)
|
caches.objects.set(unique_id, obj)
|
||||||
|
|
||||||
if hasattr(obj, "check_lazy_status"):
|
if hasattr(obj, "check_lazy_status"):
|
||||||
required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
||||||
|
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
||||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||||
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
||||||
x not in input_data_all or x in missing_keys
|
x not in input_data_all or x in missing_keys
|
||||||
@ -357,8 +410,18 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
|||||||
else:
|
else:
|
||||||
return block
|
return block
|
||||||
def pre_execute_cb(call_index):
|
def pre_execute_cb(call_index):
|
||||||
|
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
|
||||||
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||||
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||||
|
if has_pending_tasks:
|
||||||
|
pending_async_nodes[unique_id] = output_data
|
||||||
|
unblock = execution_list.add_external_block(unique_id)
|
||||||
|
async def await_completion():
|
||||||
|
tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
unblock()
|
||||||
|
asyncio.create_task(await_completion())
|
||||||
|
return (ExecutionResult.PENDING, None, None)
|
||||||
if len(output_ui) > 0:
|
if len(output_ui) > 0:
|
||||||
caches.ui.set(unique_id, {
|
caches.ui.set(unique_id, {
|
||||||
"meta": {
|
"meta": {
|
||||||
@ -401,7 +464,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
|||||||
cached_outputs.append((True, node_outputs))
|
cached_outputs.append((True, node_outputs))
|
||||||
new_node_ids = set(new_node_ids)
|
new_node_ids = set(new_node_ids)
|
||||||
for cache in caches.all:
|
for cache in caches.all:
|
||||||
cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
|
subcache = await cache.ensure_subcache_for(unique_id, new_node_ids)
|
||||||
|
subcache.clean_unused()
|
||||||
for node_id in new_output_ids:
|
for node_id in new_output_ids:
|
||||||
execution_list.add_node(node_id)
|
execution_list.add_node(node_id)
|
||||||
for link in new_output_links:
|
for link in new_output_links:
|
||||||
@ -446,6 +510,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
|||||||
|
|
||||||
return (ExecutionResult.FAILURE, error_details, ex)
|
return (ExecutionResult.FAILURE, error_details, ex)
|
||||||
|
|
||||||
|
get_progress_state().finish_progress(unique_id)
|
||||||
executed.add(unique_id)
|
executed.add(unique_id)
|
||||||
|
|
||||||
return (ExecutionResult.SUCCESS, None, None)
|
return (ExecutionResult.SUCCESS, None, None)
|
||||||
@ -500,6 +565,11 @@ class PromptExecutor:
|
|||||||
self.add_message("execution_error", mes, broadcast=False)
|
self.add_message("execution_error", mes, broadcast=False)
|
||||||
|
|
||||||
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||||
|
asyncio_loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(asyncio_loop)
|
||||||
|
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
||||||
|
|
||||||
|
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||||
nodes.interrupt_processing(False)
|
nodes.interrupt_processing(False)
|
||||||
|
|
||||||
if "client_id" in extra_data:
|
if "client_id" in extra_data:
|
||||||
@ -512,9 +582,11 @@ class PromptExecutor:
|
|||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
dynamic_prompt = DynamicPrompt(prompt)
|
dynamic_prompt = DynamicPrompt(prompt)
|
||||||
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
|
reset_progress_state(prompt_id, dynamic_prompt)
|
||||||
|
add_progress_handler(WebUIProgressHandler(self.server))
|
||||||
|
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
||||||
for cache in self.caches.all:
|
for cache in self.caches.all:
|
||||||
cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||||
cache.clean_unused()
|
cache.clean_unused()
|
||||||
|
|
||||||
cached_nodes = []
|
cached_nodes = []
|
||||||
@ -527,6 +599,7 @@ class PromptExecutor:
|
|||||||
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
||||||
broadcast=False)
|
broadcast=False)
|
||||||
pending_subgraph_results = {}
|
pending_subgraph_results = {}
|
||||||
|
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||||
executed = set()
|
executed = set()
|
||||||
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||||
current_outputs = self.caches.outputs.all_node_ids()
|
current_outputs = self.caches.outputs.all_node_ids()
|
||||||
@ -534,12 +607,13 @@ class PromptExecutor:
|
|||||||
execution_list.add_node(node_id)
|
execution_list.add_node(node_id)
|
||||||
|
|
||||||
while not execution_list.is_empty():
|
while not execution_list.is_empty():
|
||||||
node_id, error, ex = execution_list.stage_node_execution()
|
node_id, error, ex = await execution_list.stage_node_execution()
|
||||||
if error is not None:
|
if error is not None:
|
||||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||||
break
|
break
|
||||||
|
|
||||||
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
|
assert node_id is not None, "Node ID should not be None at this point"
|
||||||
|
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
|
||||||
self.success = result != ExecutionResult.FAILURE
|
self.success = result != ExecutionResult.FAILURE
|
||||||
if result == ExecutionResult.FAILURE:
|
if result == ExecutionResult.FAILURE:
|
||||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||||
@ -569,7 +643,7 @@ class PromptExecutor:
|
|||||||
comfy.model_management.unload_all_models()
|
comfy.model_management.unload_all_models()
|
||||||
|
|
||||||
|
|
||||||
def validate_inputs(prompt, item, validated):
|
async def validate_inputs(prompt_id, prompt, item, validated):
|
||||||
unique_id = item
|
unique_id = item
|
||||||
if unique_id in validated:
|
if unique_id in validated:
|
||||||
return validated[unique_id]
|
return validated[unique_id]
|
||||||
@ -646,7 +720,7 @@ def validate_inputs(prompt, item, validated):
|
|||||||
errors.append(error)
|
errors.append(error)
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
r = validate_inputs(prompt, o_id, validated)
|
r = await validate_inputs(prompt_id, prompt, o_id, validated)
|
||||||
if r[0] is False:
|
if r[0] is False:
|
||||||
# `r` will be set in `validated[o_id]` already
|
# `r` will be set in `validated[o_id]` already
|
||||||
valid = False
|
valid = False
|
||||||
@ -771,7 +845,8 @@ def validate_inputs(prompt, item, validated):
|
|||||||
input_filtered['input_types'] = [received_types]
|
input_filtered['input_types'] = [received_types]
|
||||||
|
|
||||||
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
|
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
|
||||||
ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
|
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS")
|
||||||
|
ret = await resolve_map_node_over_list_results(ret)
|
||||||
for x in input_filtered:
|
for x in input_filtered:
|
||||||
for i, r in enumerate(ret):
|
for i, r in enumerate(ret):
|
||||||
if r is not True and not isinstance(r, ExecutionBlocker):
|
if r is not True and not isinstance(r, ExecutionBlocker):
|
||||||
@ -804,7 +879,7 @@ def full_type_name(klass):
|
|||||||
return klass.__qualname__
|
return klass.__qualname__
|
||||||
return module + '.' + klass.__qualname__
|
return module + '.' + klass.__qualname__
|
||||||
|
|
||||||
def validate_prompt(prompt):
|
async def validate_prompt(prompt_id, prompt):
|
||||||
outputs = set()
|
outputs = set()
|
||||||
for x in prompt:
|
for x in prompt:
|
||||||
if 'class_type' not in prompt[x]:
|
if 'class_type' not in prompt[x]:
|
||||||
@ -847,7 +922,7 @@ def validate_prompt(prompt):
|
|||||||
valid = False
|
valid = False
|
||||||
reasons = []
|
reasons = []
|
||||||
try:
|
try:
|
||||||
m = validate_inputs(prompt, o, validated)
|
m = await validate_inputs(prompt_id, prompt, o, validated)
|
||||||
valid = m[0]
|
valid = m[0]
|
||||||
reasons = m[1]
|
reasons = m[1]
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
|
21
main.py
21
main.py
@ -11,6 +11,8 @@ import itertools
|
|||||||
import utils.extra_config
|
import utils.extra_config
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
from comfy_execution.progress import get_progress_state
|
||||||
|
from comfy_execution.utils import get_executing_context
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
|
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
|
||||||
@ -131,7 +133,7 @@ import comfy.utils
|
|||||||
|
|
||||||
import execution
|
import execution
|
||||||
import server
|
import server
|
||||||
from server import BinaryEventTypes
|
from protocol import BinaryEventTypes
|
||||||
import nodes
|
import nodes
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfyui_version
|
import comfyui_version
|
||||||
@ -227,14 +229,25 @@ async def run(server_instance, address='', port=8188, verbose=True, call_on_star
|
|||||||
server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
|
server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def hijack_progress(server_instance):
|
def hijack_progress(server_instance):
|
||||||
def hook(value, total, preview_image):
|
def hook(value, total, preview_image, prompt_id=None, node_id=None):
|
||||||
|
executing_context = get_executing_context()
|
||||||
|
if prompt_id is None and executing_context is not None:
|
||||||
|
prompt_id = executing_context.prompt_id
|
||||||
|
if node_id is None and executing_context is not None:
|
||||||
|
node_id = executing_context.node_id
|
||||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||||
progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}
|
if prompt_id is None:
|
||||||
|
prompt_id = server_instance.last_prompt_id
|
||||||
|
if node_id is None:
|
||||||
|
node_id = server_instance.last_node_id
|
||||||
|
progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
|
||||||
|
get_progress_state().update_progress(node_id, value, total, preview_image)
|
||||||
|
|
||||||
server_instance.send_sync("progress", progress, server_instance.client_id)
|
server_instance.send_sync("progress", progress, server_instance.client_id)
|
||||||
if preview_image is not None:
|
if preview_image is not None:
|
||||||
|
# Also send old method for backward compatibility
|
||||||
|
# TODO - Remove after this repo is updated to frontend with metadata support
|
||||||
server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)
|
server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)
|
||||||
|
|
||||||
comfy.utils.set_progress_bar_global_hook(hook)
|
comfy.utils.set_progress_bar_global_hook(hook)
|
||||||
|
51
server.py
51
server.py
@ -35,11 +35,7 @@ from app.model_manager import ModelFileManager
|
|||||||
from app.custom_node_manager import CustomNodeManager
|
from app.custom_node_manager import CustomNodeManager
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||||
|
from protocol import BinaryEventTypes
|
||||||
class BinaryEventTypes:
|
|
||||||
PREVIEW_IMAGE = 1
|
|
||||||
UNENCODED_PREVIEW_IMAGE = 2
|
|
||||||
TEXT = 3
|
|
||||||
|
|
||||||
async def send_socket_catch_exception(function, message):
|
async def send_socket_catch_exception(function, message):
|
||||||
try:
|
try:
|
||||||
@ -643,7 +639,8 @@ class PromptServer():
|
|||||||
|
|
||||||
if "prompt" in json_data:
|
if "prompt" in json_data:
|
||||||
prompt = json_data["prompt"]
|
prompt = json_data["prompt"]
|
||||||
valid = execution.validate_prompt(prompt)
|
prompt_id = str(uuid.uuid4())
|
||||||
|
valid = await execution.validate_prompt(prompt_id, prompt)
|
||||||
extra_data = {}
|
extra_data = {}
|
||||||
if "extra_data" in json_data:
|
if "extra_data" in json_data:
|
||||||
extra_data = json_data["extra_data"]
|
extra_data = json_data["extra_data"]
|
||||||
@ -651,7 +648,6 @@ class PromptServer():
|
|||||||
if "client_id" in json_data:
|
if "client_id" in json_data:
|
||||||
extra_data["client_id"] = json_data["client_id"]
|
extra_data["client_id"] = json_data["client_id"]
|
||||||
if valid[0]:
|
if valid[0]:
|
||||||
prompt_id = str(uuid.uuid4())
|
|
||||||
outputs_to_execute = valid[2]
|
outputs_to_execute = valid[2]
|
||||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
|
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
|
||||||
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
|
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
|
||||||
@ -766,6 +762,10 @@ class PromptServer():
|
|||||||
async def send(self, event, data, sid=None):
|
async def send(self, event, data, sid=None):
|
||||||
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
|
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
|
||||||
await self.send_image(data, sid=sid)
|
await self.send_image(data, sid=sid)
|
||||||
|
elif event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA:
|
||||||
|
# data is (preview_image, metadata)
|
||||||
|
preview_image, metadata = data
|
||||||
|
await self.send_image_with_metadata(preview_image, metadata, sid=sid)
|
||||||
elif isinstance(data, (bytes, bytearray)):
|
elif isinstance(data, (bytes, bytearray)):
|
||||||
await self.send_bytes(event, data, sid)
|
await self.send_bytes(event, data, sid)
|
||||||
else:
|
else:
|
||||||
@ -804,6 +804,43 @@ class PromptServer():
|
|||||||
preview_bytes = bytesIO.getvalue()
|
preview_bytes = bytesIO.getvalue()
|
||||||
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
|
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
|
||||||
|
|
||||||
|
async def send_image_with_metadata(self, image_data, metadata=None, sid=None):
|
||||||
|
image_type = image_data[0]
|
||||||
|
image = image_data[1]
|
||||||
|
max_size = image_data[2]
|
||||||
|
if max_size is not None:
|
||||||
|
if hasattr(Image, 'Resampling'):
|
||||||
|
resampling = Image.Resampling.BILINEAR
|
||||||
|
else:
|
||||||
|
resampling = Image.Resampling.LANCZOS
|
||||||
|
|
||||||
|
image = ImageOps.contain(image, (max_size, max_size), resampling)
|
||||||
|
|
||||||
|
mimetype = "image/png" if image_type == "PNG" else "image/jpeg"
|
||||||
|
|
||||||
|
# Prepare metadata
|
||||||
|
if metadata is None:
|
||||||
|
metadata = {}
|
||||||
|
metadata["image_type"] = mimetype
|
||||||
|
|
||||||
|
# Serialize metadata as JSON
|
||||||
|
import json
|
||||||
|
metadata_json = json.dumps(metadata).encode('utf-8')
|
||||||
|
metadata_length = len(metadata_json)
|
||||||
|
|
||||||
|
# Prepare image data
|
||||||
|
bytesIO = BytesIO()
|
||||||
|
image.save(bytesIO, format=image_type, quality=95, compress_level=1)
|
||||||
|
image_bytes = bytesIO.getvalue()
|
||||||
|
|
||||||
|
# Combine metadata and image
|
||||||
|
combined_data = bytearray()
|
||||||
|
combined_data.extend(struct.pack(">I", metadata_length))
|
||||||
|
combined_data.extend(metadata_json)
|
||||||
|
combined_data.extend(image_bytes)
|
||||||
|
|
||||||
|
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, combined_data, sid=sid)
|
||||||
|
|
||||||
async def send_bytes(self, event, data, sid=None):
|
async def send_bytes(self, event, data, sid=None):
|
||||||
message = self.encode_bytes(event, data)
|
message = self.encode_bytes(event, data)
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Config for testing nodes
|
# Config for testing nodes
|
||||||
testing:
|
testing:
|
||||||
custom_nodes: tests/inference/testing_nodes
|
custom_nodes: testing_nodes
|
||||||
|
|
||||||
|
@ -252,7 +252,7 @@ class TestExecution:
|
|||||||
|
|
||||||
@pytest.mark.parametrize("test_type, test_value", [
|
@pytest.mark.parametrize("test_type, test_value", [
|
||||||
("StubInt", 5),
|
("StubInt", 5),
|
||||||
("StubFloat", 5.0)
|
("StubMask", 5.0)
|
||||||
])
|
])
|
||||||
def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder):
|
def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder):
|
||||||
g = builder
|
g = builder
|
||||||
@ -497,6 +497,69 @@ class TestExecution:
|
|||||||
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
|
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
|
||||||
assert not result.did_run(test_node), "The execution should have been cached"
|
assert not result.did_run(test_node), "The execution should have been cached"
|
||||||
|
|
||||||
|
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
g = builder
|
||||||
|
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
|
# Create sleep nodes for each duration
|
||||||
|
sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.8)
|
||||||
|
sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=2.9)
|
||||||
|
sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0)
|
||||||
|
|
||||||
|
# Add outputs to verify the execution
|
||||||
|
output1 = g.node("PreviewImage", images=sleep_node1.out(0))
|
||||||
|
output2 = g.node("PreviewImage", images=sleep_node2.out(0))
|
||||||
|
output3 = g.node("PreviewImage", images=sleep_node3.out(0))
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
result = client.run(g)
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
|
# The test should take around 0.4 seconds (the longest sleep duration)
|
||||||
|
# plus some overhead, but definitely less than the sum of all sleeps (0.9s)
|
||||||
|
# We'll allow for up to 0.8s total to account for overhead
|
||||||
|
assert elapsed_time < 4.0, f"Parallel execution took {elapsed_time}s, expected less than 0.8s"
|
||||||
|
|
||||||
|
# Verify that all nodes executed
|
||||||
|
assert result.did_run(sleep_node1), "Sleep node 1 should have run"
|
||||||
|
assert result.did_run(sleep_node2), "Sleep node 2 should have run"
|
||||||
|
assert result.did_run(sleep_node3), "Sleep node 3 should have run"
|
||||||
|
|
||||||
|
def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
g = builder
|
||||||
|
# Create input images with different values
|
||||||
|
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
image3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
|
# Create a TestParallelSleep node that expands into multiple TestSleep nodes
|
||||||
|
parallel_sleep = g.node("TestParallelSleep",
|
||||||
|
image1=image1.out(0),
|
||||||
|
image2=image2.out(0),
|
||||||
|
image3=image3.out(0),
|
||||||
|
sleep1=0.4,
|
||||||
|
sleep2=0.5,
|
||||||
|
sleep3=0.6)
|
||||||
|
output = g.node("SaveImage", images=parallel_sleep.out(0))
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
result = client.run(g)
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Similar to the previous test, expect parallel execution of the sleep nodes
|
||||||
|
# which should complete in less than the sum of all sleeps
|
||||||
|
assert elapsed_time < 0.8, f"Expansion execution took {elapsed_time}s, expected less than 0.8s"
|
||||||
|
|
||||||
|
# Verify the parallel sleep node executed
|
||||||
|
assert result.did_run(parallel_sleep), "ParallelSleep node should have run"
|
||||||
|
|
||||||
|
# Verify we get an image as output (blend of the three input images)
|
||||||
|
result_images = result.get_images(output)
|
||||||
|
assert len(result_images) == 1, "Should have 1 image"
|
||||||
|
# Average pixel value should be around 170 (255 * 2 // 3)
|
||||||
|
avg_value = numpy.array(result_images[0]).mean()
|
||||||
|
assert avg_value == 170, f"Image average value {avg_value} should be 170"
|
||||||
|
|
||||||
# This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
|
# This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
|
||||||
# as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
|
# as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
|
||||||
# only that one entry in the list is blocked.
|
# only that one entry in the list is blocked.
|
||||||
|
@ -1,6 +1,11 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
from comfy.utils import ProgressBar
|
||||||
from .tools import VariantSupport
|
from .tools import VariantSupport
|
||||||
from comfy_execution.graph_utils import GraphBuilder
|
from comfy_execution.graph_utils import GraphBuilder
|
||||||
|
from comfy.comfy_types.node_typing import ComfyNodeABC
|
||||||
|
from comfy.comfy_types import IO
|
||||||
|
|
||||||
class TestLazyMixImages:
|
class TestLazyMixImages:
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -333,6 +338,131 @@ class TestMixedExpansionReturns:
|
|||||||
"expand": g.finalize(),
|
"expand": g.finalize(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class TestSamplingInExpansion:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("MODEL",),
|
||||||
|
"clip": ("CLIP",),
|
||||||
|
"vae": ("VAE",),
|
||||||
|
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||||
|
"steps": ("INT", {"default": 20, "min": 1, "max": 100}),
|
||||||
|
"cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 30.0}),
|
||||||
|
"prompt": ("STRING", {"multiline": True, "default": "a beautiful landscape with mountains and trees"}),
|
||||||
|
"negative_prompt": ("STRING", {"multiline": True, "default": "blurry, bad quality, worst quality"}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "sampling_in_expansion"
|
||||||
|
|
||||||
|
CATEGORY = "Testing/Nodes"
|
||||||
|
|
||||||
|
def sampling_in_expansion(self, model, clip, vae, seed, steps, cfg, prompt, negative_prompt):
|
||||||
|
g = GraphBuilder()
|
||||||
|
|
||||||
|
# Create a basic image generation workflow using the input model, clip and vae
|
||||||
|
# 1. Setup text prompts using the provided CLIP model
|
||||||
|
positive_prompt = g.node("CLIPTextEncode",
|
||||||
|
text=prompt,
|
||||||
|
clip=clip)
|
||||||
|
negative_prompt = g.node("CLIPTextEncode",
|
||||||
|
text=negative_prompt,
|
||||||
|
clip=clip)
|
||||||
|
|
||||||
|
# 2. Create empty latent with specified size
|
||||||
|
empty_latent = g.node("EmptyLatentImage", width=512, height=512, batch_size=1)
|
||||||
|
|
||||||
|
# 3. Setup sampler and generate image latent
|
||||||
|
sampler = g.node("KSampler",
|
||||||
|
model=model,
|
||||||
|
positive=positive_prompt.out(0),
|
||||||
|
negative=negative_prompt.out(0),
|
||||||
|
latent_image=empty_latent.out(0),
|
||||||
|
seed=seed,
|
||||||
|
steps=steps,
|
||||||
|
cfg=cfg,
|
||||||
|
sampler_name="euler_ancestral",
|
||||||
|
scheduler="normal")
|
||||||
|
|
||||||
|
# 4. Decode latent to image using VAE
|
||||||
|
output = g.node("VAEDecode", samples=sampler.out(0), vae=vae)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"result": (output.out(0),),
|
||||||
|
"expand": g.finalize(),
|
||||||
|
}
|
||||||
|
|
||||||
|
class TestSleep(ComfyNodeABC):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"value": (IO.ANY, {}),
|
||||||
|
"seconds": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 9999.0, "step": 0.01, "tooltip": "The amount of seconds to sleep."}),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
RETURN_TYPES = (IO.ANY,)
|
||||||
|
FUNCTION = "sleep"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
async def sleep(self, value, seconds, unique_id):
|
||||||
|
pbar = ProgressBar(seconds, node_id=unique_id)
|
||||||
|
start = time.time()
|
||||||
|
expiration = start + seconds
|
||||||
|
now = start
|
||||||
|
while now < expiration:
|
||||||
|
now = time.time()
|
||||||
|
pbar.update_absolute(now - start)
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
return (value,)
|
||||||
|
|
||||||
|
class TestParallelSleep(ComfyNodeABC):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image1": ("IMAGE", ),
|
||||||
|
"image2": ("IMAGE", ),
|
||||||
|
"image3": ("IMAGE", ),
|
||||||
|
"sleep1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"sleep2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"sleep3": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "parallel_sleep"
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id):
|
||||||
|
# Create a graph dynamically with three TestSleep nodes
|
||||||
|
g = GraphBuilder()
|
||||||
|
|
||||||
|
# Create sleep nodes for each duration and image
|
||||||
|
sleep_node1 = g.node("TestSleep", value=image1, seconds=sleep1)
|
||||||
|
sleep_node2 = g.node("TestSleep", value=image2, seconds=sleep2)
|
||||||
|
sleep_node3 = g.node("TestSleep", value=image3, seconds=sleep3)
|
||||||
|
|
||||||
|
# Blend the results using TestVariadicAverage
|
||||||
|
blend = g.node("TestVariadicAverage",
|
||||||
|
input1=sleep_node1.out(0),
|
||||||
|
input2=sleep_node2.out(0),
|
||||||
|
input3=sleep_node3.out(0))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"result": (blend.out(0),),
|
||||||
|
"expand": g.finalize(),
|
||||||
|
}
|
||||||
|
|
||||||
TEST_NODE_CLASS_MAPPINGS = {
|
TEST_NODE_CLASS_MAPPINGS = {
|
||||||
"TestLazyMixImages": TestLazyMixImages,
|
"TestLazyMixImages": TestLazyMixImages,
|
||||||
"TestVariadicAverage": TestVariadicAverage,
|
"TestVariadicAverage": TestVariadicAverage,
|
||||||
@ -345,6 +475,9 @@ TEST_NODE_CLASS_MAPPINGS = {
|
|||||||
"TestCustomValidation5": TestCustomValidation5,
|
"TestCustomValidation5": TestCustomValidation5,
|
||||||
"TestDynamicDependencyCycle": TestDynamicDependencyCycle,
|
"TestDynamicDependencyCycle": TestDynamicDependencyCycle,
|
||||||
"TestMixedExpansionReturns": TestMixedExpansionReturns,
|
"TestMixedExpansionReturns": TestMixedExpansionReturns,
|
||||||
|
"TestSamplingInExpansion": TestSamplingInExpansion,
|
||||||
|
"TestSleep": TestSleep,
|
||||||
|
"TestParallelSleep": TestParallelSleep,
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
@ -359,4 +492,7 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"TestCustomValidation5": "Custom Validation 5",
|
"TestCustomValidation5": "Custom Validation 5",
|
||||||
"TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
|
"TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
|
||||||
"TestMixedExpansionReturns": "Mixed Expansion Returns",
|
"TestMixedExpansionReturns": "Mixed Expansion Returns",
|
||||||
|
"TestSamplingInExpansion": "Sampling In Expansion",
|
||||||
|
"TestSleep": "Test Sleep",
|
||||||
|
"TestParallelSleep": "Test Parallel Sleep",
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user