from typing import TypedDict, Dict, Optional
from typing_extensions import override
from PIL import Image
from enum import Enum
from abc import ABC
from tqdm import tqdm
from comfy_execution.graph import DynamicPrompt
from protocol import BinaryEventTypes


class NodeState(Enum):
    """Lifecycle states a node's progress entry can be in."""

    Pending = "pending"
    Running = "running"
    Finished = "finished"
    Error = "error"


class NodeProgressState(TypedDict):
    """
    A class to represent the state of a node's progress.
    """

    state: NodeState  # current lifecycle state of the node
    value: float  # progress reached so far
    max: float  # progress value that represents completion


class ProgressHandler(ABC):
    """
    Abstract base class for progress handlers.

    Progress handlers receive progress updates and display them in various ways.
    Every hook is a no-op here, so subclasses only override the ones they need.
    """

    def __init__(self, name: str):
        # The registry stores handlers keyed by this name.
        self.name = name
        self.enabled = True

    def set_registry(self, registry: "ProgressRegistry"):
        """Called when the handler is attached to a registry; no-op by default."""
        pass

    def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Called when a node starts processing"""
        pass

    def update_handler(
        self,
        node_id: str,
        value: float,
        max_value: float,
        state: NodeProgressState,
        prompt_id: str,
        image: Optional[Image.Image] = None,
    ):
        """Called when a node's progress is updated"""
        pass

    def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Called when a node finishes processing"""
        pass

    def reset(self):
        """Called when the progress registry is reset"""
        pass

    def enable(self):
        """Enable this handler"""
        self.enabled = True

    def disable(self):
        """Disable this handler"""
        self.enabled = False


class CLIProgressHandler(ProgressHandler):
    """
    Handler that displays progress using tqdm progress bars in the CLI.
    """

    def __init__(self):
        super().__init__("cli")
        # One open tqdm bar per node id whose progress is being displayed.
        self.progress_bars: Dict[str, tqdm] = {}

    def _create_bar(self, node_id: str, total: float) -> tqdm:
        """Create a tqdm bar for ``node_id``, store it, and return it."""
        bar = tqdm(
            total=total,
            desc=f"Node {node_id}",
            unit="steps",
            leave=True,
            # Stack the new bar below the bars that are already open.
            position=len(self.progress_bars),
        )
        self.progress_bars[node_id] = bar
        return bar

    @override
    def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        # Create a new tqdm progress bar
        if node_id not in self.progress_bars:
            self._create_bar(node_id, state["max"])

    @override
    def update_handler(
        self,
        node_id: str,
        value: float,
        max_value: float,
        state: NodeProgressState,
        prompt_id: str,
        image: Optional[Image.Image] = None,
    ):
        bar = self.progress_bars.get(node_id)
        if bar is None:
            # Handle case where start_handler wasn't called
            self._create_bar(node_id, max_value).update(value)
            return

        if max_value != bar.total:
            bar.total = max_value
        # tqdm.update() takes an increment, so convert the absolute value into a
        # delta from the bar's current position. Backward moves are ignored.
        update_amount = value - bar.n
        if update_amount > 0:
            bar.update(update_amount)

    @override
    def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        # Complete and close the progress bar if it exists
        bar = self.progress_bars.pop(node_id, None)
        if bar is None:
            return
        # Ensure the bar shows 100% completion
        remaining = state["max"] - bar.n
        if remaining > 0:
            bar.update(remaining)
        bar.close()

    @override
    def reset(self):
        # Close all progress bars
        for bar in self.progress_bars.values():
            bar.close()
        self.progress_bars.clear()


class WebUIProgressHandler(ProgressHandler):
    """
    Handler that sends progress updates to the WebUI via WebSockets.

    Messages are only sent once a registry has been attached via set_registry()
    and when a server instance is available.
    """

    def __init__(self, server_instance):
        super().__init__("webui")
        self.server_instance = server_instance
        # Set by set_registry(); initialised here so the hooks can safely check it
        # even if the handler was registered without going through set_registry().
        self.registry: Optional["ProgressRegistry"] = None

    @override
    def set_registry(self, registry: "ProgressRegistry"):
        self.registry = registry

    def _node_metadata(self, node_id: str, prompt_id: str) -> dict:
        """Build the node/prompt identification fields sent with every message."""
        dynprompt = self.registry.dynprompt
        return {
            "node_id": node_id,
            "prompt_id": prompt_id,
            "display_node_id": dynprompt.get_display_node_id(node_id),
            "parent_node_id": dynprompt.get_parent_node_id(node_id),
            "real_node_id": dynprompt.get_real_node_id(node_id),
        }

    def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]):
        """Send the current progress state to the client"""
        if self.server_instance is None or self.registry is None:
            return

        # Only send info for non-pending nodes
        active_nodes = {
            node_id: {
                "value": state["value"],
                "max": state["max"],
                "state": state["state"].value,
                **self._node_metadata(node_id, prompt_id),
            }
            for node_id, state in nodes.items()
            if state["state"] != NodeState.Pending
        }

        # Send a combined progress_state message with all node states
        self.server_instance.send_sync("progress_state", {
            "prompt_id": prompt_id,
            "nodes": active_nodes,
        })

    @override
    def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        # Send progress state of all nodes
        if self.registry is not None:
            self._send_progress_state(prompt_id, self.registry.nodes)

    @override
    def update_handler(
        self,
        node_id: str,
        value: float,
        max_value: float,
        state: NodeProgressState,
        prompt_id: str,
        image: Optional[Image.Image] = None,
    ):
        # Without a registry there is no dynprompt to resolve node ids against,
        # so neither the state message nor the preview can be built.
        if self.registry is None:
            return

        # Send progress state of all nodes
        self._send_progress_state(prompt_id, self.registry.nodes)

        if image is not None and self.server_instance is not None:
            self.server_instance.send_sync(
                BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA,
                (image, self._node_metadata(node_id, prompt_id)),
                self.server_instance.client_id,
            )

    @override
    def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        # Send progress state of all nodes
        if self.registry is not None:
            self._send_progress_state(prompt_id, self.registry.nodes)


class ProgressRegistry:
    """
    Registry that maintains node progress state and notifies registered handlers.
    """

    def __init__(self, prompt_id: str, dynprompt: DynamicPrompt):
        self.prompt_id = prompt_id
        self.dynprompt = dynprompt
        # Progress entry per node id; created lazily by ensure_entry().
        self.nodes: Dict[str, NodeProgressState] = {}
        # Handlers keyed by their name; registering a second handler with the
        # same name replaces the first.
        self.handlers: Dict[str, ProgressHandler] = {}

    def register_handler(self, handler: ProgressHandler) -> None:
        """Register a progress handler"""
        self.handlers[handler.name] = handler

    def unregister_handler(self, handler_name: str) -> None:
        """Unregister a progress handler"""
        if handler_name in self.handlers:
            # Allow handler to clean up resources
            self.handlers[handler_name].reset()
            del self.handlers[handler_name]

    def enable_handler(self, handler_name: str) -> None:
        """Enable a progress handler"""
        if handler_name in self.handlers:
            self.handlers[handler_name].enable()

    def disable_handler(self, handler_name: str) -> None:
        """Disable a progress handler"""
        if handler_name in self.handlers:
            self.handlers[handler_name].disable()

    def ensure_entry(self, node_id: str) -> NodeProgressState:
        """Ensure a node entry exists (Pending, value 0, max 1) and return it."""
        if node_id not in self.nodes:
            self.nodes[node_id] = NodeProgressState(
                state=NodeState.Pending, value=0, max=1
            )
        return self.nodes[node_id]

    def start_progress(self, node_id: str) -> None:
        """Start progress tracking for a node"""
        entry = self.ensure_entry(node_id)
        entry["state"] = NodeState.Running
        entry["value"] = 0.0
        entry["max"] = 1.0

        # Notify all enabled handlers
        for handler in self.handlers.values():
            if handler.enabled:
                handler.start_handler(node_id, entry, self.prompt_id)

    def update_progress(
        self,
        node_id: str,
        value: float,
        max_value: float,
        image: Optional[Image.Image] = None,
    ) -> None:
        """Update progress for a node, optionally attaching a preview image."""
        entry = self.ensure_entry(node_id)
        entry["state"] = NodeState.Running
        entry["value"] = value
        entry["max"] = max_value

        # Notify all enabled handlers
        for handler in self.handlers.values():
            if handler.enabled:
                handler.update_handler(
                    node_id, value, max_value, entry, self.prompt_id, image
                )

    def finish_progress(self, node_id: str) -> None:
        """Finish progress tracking for a node"""
        entry = self.ensure_entry(node_id)
        entry["state"] = NodeState.Finished
        entry["value"] = entry["max"]

        # Notify all enabled handlers
        for handler in self.handlers.values():
            if handler.enabled:
                handler.finish_handler(node_id, entry, self.prompt_id)

    def reset_handlers(self) -> None:
        """Reset all handlers (enabled or not); they stay registered."""
        for handler in self.handlers.values():
            handler.reset()


# Global registry instance
global_progress_registry: ProgressRegistry = ProgressRegistry(
    prompt_id="", dynprompt=DynamicPrompt({})
)


def reset_progress_state(prompt_id: str, dynprompt: DynamicPrompt) -> None:
    """Reset the handlers of the current global registry and replace it.

    The new registry starts with no nodes and no handlers; handlers must be
    attached again with add_progress_handler().
    """
    global global_progress_registry

    # Reset existing handlers so they can release resources (e.g. open CLI bars).
    global_progress_registry.reset_handlers()

    # Create new registry
    global_progress_registry = ProgressRegistry(prompt_id, dynprompt)


def add_progress_handler(handler: ProgressHandler) -> None:
    """Attach ``handler`` to the current global registry."""
    handler.set_registry(global_progress_registry)
    global_progress_registry.register_handler(handler)


def get_progress_state() -> ProgressRegistry:
    """Return the current global progress registry."""
    return global_progress_registry