mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-13 13:05:07 +00:00
Display progress and result URL directly on API nodes (#8102)
* [Luma] Print download URL of successful task result directly on nodes (#177) [Veo] Print download URL of successful task result directly on nodes (#184) [Recraft] Print download URL of successful task result directly on nodes (#183) [Pixverse] Print download URL of successful task result directly on nodes (#182) [Kling] Print download URL of successful task result directly on nodes (#181) [MiniMax] Print progress text and download URL of successful task result directly on nodes (#179) [Docs] Link to docs in `API_NODE` class property type annotation comment (#178) [Ideogram] Print download URL of successful task result directly on nodes (#176) [Kling] Print download URL of successful task result directly on nodes (#181) [Veo] Print download URL of successful task result directly on nodes (#184) [Recraft] Print download URL of successful task result directly on nodes (#183) [Pixverse] Print download URL of successful task result directly on nodes (#182) [MiniMax] Print progress text and download URL of successful task result directly on nodes (#179) [Docs] Link to docs in `API_NODE` class property type annotation comment (#178) [Luma] Print download URL of successful task result directly on nodes (#177) [Ideogram] Print download URL of successful task result directly on nodes (#176) Show output URL and progress text on Pika nodes (#168) [BFL] Print download URL of successful task result directly on nodes (#175) [OpenAI ] Print download URL of successful task result directly on nodes (#174) * fix ruff errors * fix 3.10 syntax error
This commit is contained in:
@@ -103,6 +103,7 @@ from urllib.parse import urljoin, urlparse
|
||||
from pydantic import BaseModel, Field
|
||||
import uuid # For generating unique operation IDs
|
||||
|
||||
from server import PromptServer
|
||||
from comfy.cli_args import args
|
||||
from comfy import utils
|
||||
from . import request_logger
|
||||
@@ -900,6 +901,7 @@ class PollingOperation(Generic[T, R]):
|
||||
failed_statuses: list,
|
||||
status_extractor: Callable[[R], str],
|
||||
progress_extractor: Callable[[R], float] = None,
|
||||
result_url_extractor: Callable[[R], str] = None,
|
||||
request: Optional[T] = None,
|
||||
api_base: str | None = None,
|
||||
auth_token: Optional[str] = None,
|
||||
@@ -910,6 +912,8 @@ class PollingOperation(Generic[T, R]):
|
||||
max_retries: int = 3, # Max retries per individual API call
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff_factor: float = 2.0,
|
||||
estimated_duration: Optional[float] = None,
|
||||
node_id: Optional[str] = None,
|
||||
):
|
||||
self.poll_endpoint = poll_endpoint
|
||||
self.request = request
|
||||
@@ -924,12 +928,15 @@ class PollingOperation(Generic[T, R]):
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.retry_backoff_factor = retry_backoff_factor
|
||||
self.estimated_duration = estimated_duration
|
||||
|
||||
# Polling configuration
|
||||
self.status_extractor = status_extractor or (
|
||||
lambda x: getattr(x, "status", None)
|
||||
)
|
||||
self.progress_extractor = progress_extractor
|
||||
self.result_url_extractor = result_url_extractor
|
||||
self.node_id = node_id
|
||||
self.completed_statuses = completed_statuses
|
||||
self.failed_statuses = failed_statuses
|
||||
|
||||
@@ -965,6 +972,26 @@ class PollingOperation(Generic[T, R]):
|
||||
except Exception as e:
|
||||
raise Exception(f"Error during polling: {str(e)}")
|
||||
|
||||
def _display_text_on_node(self, text: str):
|
||||
"""Sends text to the client which will be displayed on the node in the UI"""
|
||||
if not self.node_id:
|
||||
return
|
||||
|
||||
PromptServer.instance.send_progress_text(text, self.node_id)
|
||||
|
||||
def _display_time_progress_on_node(self, time_completed: int):
|
||||
if not self.node_id:
|
||||
return
|
||||
|
||||
if self.estimated_duration is not None:
|
||||
estimated_time_remaining = max(
|
||||
0, int(self.estimated_duration) - int(time_completed)
|
||||
)
|
||||
message = f"Task in progress: {time_completed:.0f}s (~{estimated_time_remaining:.0f}s remaining)"
|
||||
else:
|
||||
message = f"Task in progress: {time_completed:.0f}s"
|
||||
self._display_text_on_node(message)
|
||||
|
||||
def _check_task_status(self, response: R) -> TaskStatus:
|
||||
"""Check task status using the status extractor function"""
|
||||
try:
|
||||
@@ -1031,7 +1058,15 @@ class PollingOperation(Generic[T, R]):
|
||||
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
|
||||
|
||||
if status == TaskStatus.COMPLETED:
|
||||
logging.debug("[DEBUG] Task completed successfully")
|
||||
message = "Task completed successfully"
|
||||
if self.result_url_extractor:
|
||||
result_url = self.result_url_extractor(response_obj)
|
||||
if result_url:
|
||||
message = f"Result URL: {result_url}"
|
||||
else:
|
||||
message = "Task completed successfully!"
|
||||
logging.debug(f"[DEBUG] {message}")
|
||||
self._display_text_on_node(message)
|
||||
self.final_response = response_obj
|
||||
if self.progress_extractor:
|
||||
progress.update(100)
|
||||
@@ -1047,7 +1082,10 @@ class PollingOperation(Generic[T, R]):
|
||||
logging.debug(
|
||||
f"[DEBUG] Waiting {self.poll_interval} seconds before next poll"
|
||||
)
|
||||
time.sleep(self.poll_interval)
|
||||
for i in range(int(self.poll_interval)):
|
||||
time_completed = (poll_count * self.poll_interval) + i
|
||||
self._display_time_progress_on_node(time_completed)
|
||||
time.sleep(1)
|
||||
|
||||
except (LocalNetworkError, ApiServerError) as e:
|
||||
# For network-related errors, increment error count and potentially abort
|
||||
|
Reference in New Issue
Block a user