Alexander Piskun bf2a1b5b1e
async API nodes (#9129)
* converted API nodes to async

* converted BFL API nodes to async

* fixed client bug; converted gemini, ideogram, minimax

* fixed client bug; converted openai nodes

* fixed client bug; converted moonvalley, pika nodes

* fixed client bug; converted kling, luma nodes

* converted pixverse, rodin nodes

* converted tripo, veo2

* converted recraft nodes

* add lost log_request_response call
2025-08-07 23:37:50 -04:00

958 lines
37 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
API Client Framework for api.comfy.org.
This module provides a flexible framework for making API requests from ComfyUI nodes.
It supports both synchronous and asynchronous API operations with proper type validation.
Key Components:
--------------
1. ApiClient - Handles HTTP requests with authentication and error handling
2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models
3. ApiOperation - Executes a single synchronous API operation
Usage Examples:
--------------
# Example 1: Synchronous API Operation
# ------------------------------------
# For a simple API call that returns the result immediately:
# 1. Create the API client
api_client = ApiClient(
base_url="https://api.example.com",
auth_token="your_auth_token_here",
comfy_api_key="your_comfy_api_key_here",
timeout=30.0,
verify_ssl=True
)
# 2. Define the endpoint
user_info_endpoint = ApiEndpoint(
path="/v1/users/me",
method=HttpMethod.GET,
request_model=EmptyRequest, # No request body needed
response_model=UserProfile, # Pydantic model for the response
query_params=None
)
# 3. Create the request object
request = EmptyRequest()
# 4. Create and execute the operation
operation = ApiOperation(
endpoint=user_info_endpoint,
request=request
)
user_profile = await operation.execute(client=api_client) # Returns immediately with the result
# Example 2: Asynchronous API Operation with Polling
# -------------------------------------------------
# For an API that starts a task and requires polling for completion:
# 1. Define the endpoints (initial request and polling)
generate_image_endpoint = ApiEndpoint(
path="/v1/images/generate",
method=HttpMethod.POST,
request_model=ImageGenerationRequest,
response_model=TaskCreatedResponse,
query_params=None
)
check_task_endpoint = ApiEndpoint(
path="/v1/tasks/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=ImageGenerationResult,
query_params=None
)
# 2. Create the request object
request = ImageGenerationRequest(
prompt="a beautiful sunset over mountains",
width=1024,
height=1024,
num_images=1
)
# 3. Create and execute the polling operation
operation = PollingOperation(
initial_endpoint=generate_image_endpoint,
initial_request=request,
poll_endpoint=check_task_endpoint,
task_id_field="task_id",
status_field="status",
completed_statuses=["completed"],
failed_statuses=["failed", "error"]
)
# This will make the initial request and then poll until completion
result = await operation.execute(client=api_client) # Returns the final ImageGenerationResult when done
"""
from __future__ import annotations
import aiohttp
import asyncio
import logging
import io
import socket
from aiohttp.client_exceptions import ClientError, ClientResponseError
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
from enum import Enum
import json
from urllib.parse import urljoin, urlparse
from pydantic import BaseModel, Field
import uuid # For generating unique operation IDs
from server import PromptServer
from comfy.cli_args import args
from comfy import utils
from . import request_logger
T = TypeVar("T", bound=BaseModel)
R = TypeVar("R", bound=BaseModel)
P = TypeVar("P", bound=BaseModel) # For poll response
PROGRESS_BAR_MAX = 100
class NetworkError(Exception):
"""Base exception for network-related errors with diagnostic information."""
pass
class LocalNetworkError(NetworkError):
"""Exception raised when local network connectivity issues are detected."""
pass
class ApiServerError(NetworkError):
"""Exception raised when the API server is unreachable but internet is working."""
pass
class EmptyRequest(BaseModel):
"""Base class for empty request bodies.
For GET requests, fields will be sent as query parameters."""
pass
class UploadRequest(BaseModel):
file_name: str = Field(..., description="Filename to upload")
content_type: Optional[str] = Field(
None,
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
)
class UploadResponse(BaseModel):
download_url: str = Field(..., description="URL to GET uploaded file")
upload_url: str = Field(..., description="URL to PUT file to upload")
class HttpMethod(str, Enum):
GET = "GET"
POST = "POST"
PUT = "PUT"
DELETE = "DELETE"
PATCH = "PATCH"
class ApiClient:
"""
Client for making HTTP requests to an API with authentication, error handling, and retry logic.
"""
def __init__(
self,
base_url: str,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
timeout: float = 3600.0,
verify_ssl: bool = True,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
retry_status_codes: Optional[Tuple[int, ...]] = None,
session: Optional[aiohttp.ClientSession] = None,
):
self.base_url = base_url
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
self.timeout = timeout
self.verify_ssl = verify_ssl
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
# Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests),
# 500, 502, 503, 504 (Server Errors)
self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504)
self._session: Optional[aiohttp.ClientSession] = session
self._owns_session = session is None # Track if we have to close it
@staticmethod
def _generate_operation_id(path: str) -> str:
"""Generates a unique operation ID for logging."""
return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}"
@staticmethod
def _create_json_payload_args(
data: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
return {
"json": data,
"headers": headers,
}
def _create_form_data_args(
self,
data: Dict[str, Any] | None,
files: Dict[str, Any] | None,
headers: Optional[Dict[str, str]] = None,
multipart_parser: Callable | None = None,
) -> Dict[str, Any]:
if headers and "Content-Type" in headers:
del headers["Content-Type"]
if multipart_parser and data:
data = multipart_parser(data)
form = aiohttp.FormData(default_to_multipart=True)
if data: # regular text fields
for k, v in data.items():
if v is None:
continue # aiohttp fails to serialize "None" values
# aiohttp expects strings or bytes; convert enums etc.
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
if files:
file_iter = files if isinstance(files, list) else files.items()
for field_name, file_obj in file_iter:
if file_obj is None:
continue # aiohttp fails to serialize "None" values
# file_obj can be (filename, bytes/io.BytesIO, content_type) tuple
if isinstance(file_obj, tuple):
filename, file_value, content_type = self._unpack_tuple(file_obj)
else:
file_value = file_obj
filename = getattr(file_obj, "name", field_name)
content_type = "application/octet-stream"
form.add_field(
name=field_name,
value=file_value,
filename=filename,
content_type=content_type,
)
return {"data": form, "headers": headers or {}}
@staticmethod
def _create_urlencoded_form_data_args(
data: Dict[str, Any],
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
headers = headers or {}
headers["Content-Type"] = "application/x-www-form-urlencoded"
return {
"data": data,
"headers": headers,
}
def get_headers(self) -> Dict[str, str]:
"""Get headers for API requests, including authentication if available"""
headers = {"Content-Type": "application/json", "Accept": "application/json"}
if self.auth_token:
headers["Authorization"] = f"Bearer {self.auth_token}"
elif self.comfy_api_key:
headers["X-API-KEY"] = self.comfy_api_key
return headers
async def _check_connectivity(self, target_url: str) -> Dict[str, bool]:
"""
Check connectivity to determine if network issues are local or server-related.
Args:
target_url: URL to check connectivity to
Returns:
Dictionary with connectivity status details
"""
results = {
"internet_accessible": False,
"api_accessible": False,
"is_local_issue": False,
"is_api_issue": False,
}
timeout = aiohttp.ClientTimeout(total=5.0)
async with aiohttp.ClientSession(timeout=timeout) as session:
try:
async with session.get("https://www.google.com", ssl=self.verify_ssl) as resp:
results["internet_accessible"] = resp.status < 500
except (ClientError, asyncio.TimeoutError, socket.gaierror):
results["is_local_issue"] = True
return results # cannot reach the internet early exit
# Now check API health endpoint
parsed = urlparse(target_url)
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
try:
async with session.get(health_url, ssl=self.verify_ssl) as resp:
results["api_accessible"] = resp.status < 500
except ClientError:
pass # leave as False
results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"]
return results
async def request(
self,
method: str,
path: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
headers: Optional[Dict[str, str]] = None,
content_type: str = "application/json",
multipart_parser: Callable | None = None,
retry_count: int = 0, # Used internally for tracking retries
) -> Dict[str, Any]:
"""
Make an HTTP request to the API with automatic retries for transient errors.
Args:
method: HTTP method (GET, POST, etc.)
path: API endpoint path (will be joined with base_url)
params: Query parameters
data: body data
files: Files to upload
headers: Additional headers
content_type: Content type of the request. Defaults to application/json.
retry_count: Internal parameter for tracking retries, do not set manually
Returns:
Parsed JSON response
Raises:
LocalNetworkError: If local network connectivity issues are detected
ApiServerError: If the API server is unreachable but internet is working
Exception: For other request failures
"""
# Build full URL and merge headers
relative_path = path.lstrip("/")
url = urljoin(self.base_url, relative_path)
self._check_auth(self.auth_token, self.comfy_api_key)
request_headers = self.get_headers()
if headers:
request_headers.update(headers)
if files:
request_headers.pop("Content-Type", None)
if params:
params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values
logging.debug(f"[DEBUG] Request Headers: {request_headers}")
logging.debug(f"[DEBUG] Files: {files}")
logging.debug(f"[DEBUG] Params: {params}")
logging.debug(f"[DEBUG] Data: {data}")
if content_type == "application/x-www-form-urlencoded":
payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
elif content_type == "multipart/form-data":
payload_args = self._create_form_data_args(data, files, request_headers, multipart_parser)
else:
payload_args = self._create_json_payload_args(data, request_headers)
operation_id = self._generate_operation_id(path)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=request_headers,
request_params=params,
request_data=data if content_type == "application/json" else "[form-data or other]",
)
session = await self._get_session()
try:
async with session.request(
method,
url,
params=params,
ssl=self.verify_ssl,
**payload_args,
) as resp:
if resp.status >= 400:
try:
error_data = await resp.json()
except (aiohttp.ContentTypeError, json.JSONDecodeError):
error_data = await resp.text()
return await self._handle_http_error(
ClientResponseError(resp.request_info, resp.history, status=resp.status, message=error_data),
operation_id,
method,
url,
params,
data,
files,
headers,
content_type,
multipart_parser,
retry_count=retry_count,
response_content=error_data,
)
# Success parse JSON (safely) and log
try:
payload = await resp.json()
response_content_to_log = payload
except (aiohttp.ContentTypeError, json.JSONDecodeError):
payload = {}
response_content_to_log = await resp.text()
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=response_content_to_log,
)
return payload
except (ClientError, asyncio.TimeoutError, socket.gaierror) as e:
# Treat as *connection* problem optionally retry, else escalate
if retry_count < self.max_retries:
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning("Connection error. Retrying in %.2fs (%s/%s): %s", delay, retry_count + 1,
self.max_retries, str(e))
await asyncio.sleep(delay)
return await self.request(
method,
path,
params=params,
data=data,
files=files,
headers=headers,
content_type=content_type,
multipart_parser=multipart_parser,
retry_count=retry_count + 1,
)
# One final connectivity check for diagnostics
connectivity = await self._check_connectivity(self.base_url)
if connectivity["is_local_issue"]:
raise LocalNetworkError(
"Unable to connect to the API server due to local network issues. "
"Please check your internet connection and try again."
) from e
raise ApiServerError(
f"The API server at {self.base_url} is currently unreachable. "
f"The service may be experiencing issues. Please try again later."
) from e
@staticmethod
def _check_auth(auth_token, comfy_api_key):
"""Verify that an auth token is present or comfy_api_key is present"""
if auth_token is None and comfy_api_key is None:
raise Exception("Unauthorized: Please login first to use this node.")
return auth_token or comfy_api_key
@staticmethod
async def upload_file(
upload_url: str,
file: io.BytesIO | str,
content_type: str | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
) -> aiohttp.ClientResponse:
"""Upload a file to the API with retry logic.
Args:
upload_url: The URL to upload to
file: Either a file path string, BytesIO object, or tuple of (file_path, filename)
content_type: Optional mime type to set for the upload
max_retries: Maximum number of retry attempts
retry_delay: Initial delay between retries in seconds
retry_backoff_factor: Multiplier for the delay after each retry
"""
headers: Dict[str, str] = {}
skip_auto_headers: set[str] = set()
if content_type:
headers["Content-Type"] = content_type
else:
# tell aiohttp not to add Content-Type that will break the request signature and result in a 403 status.
skip_auto_headers.add("Content-Type")
# Extract file bytes
if isinstance(file, io.BytesIO):
file.seek(0)
data = file.read()
elif isinstance(file, str):
with open(file, "rb") as f:
data = f.read()
else:
raise ValueError("File must be BytesIO or str path")
operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}"
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers,
request_data=f"[File data {len(data)} bytes]",
)
delay = retry_delay
for attempt in range(max_retries + 1):
try:
timeout = aiohttp.ClientTimeout(total=None) # honour server side timeouts
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.put(
upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers,
) as resp:
resp.raise_for_status()
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content="File uploaded successfully.",
)
return resp
except (ClientError, asyncio.TimeoutError) as e:
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
response_status_code=e.status if hasattr(e, "status") else None,
response_headers=dict(e.headers) if getattr(e, "headers") else None,
response_content=None,
error_message=f"{type(e).__name__}: {str(e)}",
)
if attempt < max_retries:
logging.warning(
"Upload failed (%s/%s). Retrying in %.2fs. %s", attempt + 1, max_retries, delay, str(e)
)
await asyncio.sleep(delay)
delay *= retry_backoff_factor
else:
raise NetworkError(f"Failed to upload file after {max_retries + 1} attempts: {e}") from e
async def _handle_http_error(
self,
exc: ClientResponseError,
operation_id: str,
*req_meta,
retry_count: int,
response_content: dict | str = "",
) -> Dict[str, Any]:
status_code = exc.status
if status_code == 401:
user_friendly = "Unauthorized: Please login first to use this node."
elif status_code == 402:
user_friendly = "Payment Required: Please add credits to your account to use this node."
elif status_code == 409:
user_friendly = "There is a problem with your account. Please contact support@comfy.org."
elif status_code == 429:
user_friendly = "Rate Limit Exceeded: Please try again later."
else:
if isinstance(response_content, dict):
if "error" in response_content and "message" in response_content["error"]:
user_friendly = f"API Error: {response_content['error']['message']}"
if "type" in response_content["error"]:
user_friendly += f" (Type: {response_content['error']['type']})"
else: # Handle cases where error is just a JSON dict with unknown format
user_friendly = f"API Error: {json.dumps(response_content)}"
else:
if len(response_content) < 200: # Arbitrary limit for display
user_friendly = f"API Error (raw): {response_content}"
else:
user_friendly = f"API Error (raw, status {response_content})"
request_logger.log_request_response(
operation_id=operation_id,
request_method=req_meta[0],
request_url=req_meta[1],
response_status_code=exc.status,
response_headers=dict(req_meta[5]) if req_meta[5] else None,
response_content=response_content,
error_message=f"HTTP Error {exc.status}",
)
logging.debug(f"[DEBUG] API Error: {user_friendly} (Status: {status_code})")
if response_content:
logging.debug(f"[DEBUG] Response content: {response_content}")
# Retry if eligible
if status_code in self.retry_status_codes and retry_count < self.max_retries:
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning(
"HTTP error %s. Retrying in %.2fs (%s/%s)",
status_code,
delay,
retry_count + 1,
self.max_retries,
)
await asyncio.sleep(delay)
return await self.request(
req_meta[0], # method
req_meta[1].replace(self.base_url, ""), # path
params=req_meta[2],
data=req_meta[3],
files=req_meta[4],
headers=req_meta[5],
content_type=req_meta[6],
multipart_parser=req_meta[7],
retry_count=retry_count + 1,
)
raise Exception(user_friendly) from exc
@staticmethod
def _unpack_tuple(t):
"""Helper to normalise (filename, file, content_type) tuples."""
if len(t) == 3:
return t
elif len(t) == 2:
return t[0], t[1], "application/octet-stream"
else:
raise ValueError("files tuple must be (filename, file[, content_type])")
async def _get_session(self) -> aiohttp.ClientSession:
if self._session is None or self._session.closed:
timeout = aiohttp.ClientTimeout(total=self.timeout)
self._session = aiohttp.ClientSession(timeout=timeout)
self._owns_session = True
return self._session
async def close(self) -> None:
if self._owns_session and self._session and not self._session.closed:
await self._session.close()
async def __aenter__(self) -> "ApiClient":
"""Allow usage as asynccontextmanager ensures clean teardown"""
return self
async def __aexit__(self, exc_type, exc, tb):
await self.close()
class ApiEndpoint(Generic[T, R]):
"""Defines an API endpoint with its request and response types"""
def __init__(
self,
path: str,
method: HttpMethod,
request_model: Type[T],
response_model: Type[R],
query_params: Optional[Dict[str, Any]] = None,
):
"""Initialize an API endpoint definition.
Args:
path: The URL path for this endpoint, can include placeholders like {id}
method: The HTTP method to use (GET, POST, etc.)
request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint
response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint
query_params: Optional dictionary of query parameters to include in the request
"""
self.path = path
self.method = method
self.request_model = request_model
self.response_model = response_model
self.query_params = query_params or {}
class SynchronousOperation(Generic[T, R]):
"""Represents a single synchronous API operation."""
def __init__(
self,
endpoint: ApiEndpoint[T, R],
request: T,
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str, str]] = None,
timeout: float = 604800.0,
verify_ssl: bool = True,
content_type: str = "application/json",
multipart_parser: Callable | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
) -> None:
self.endpoint = endpoint
self.request = request
self.files = files
self.api_base: str = api_base or args.comfy_api_base
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
if auth_kwargs is not None:
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
self.timeout = timeout
self.verify_ssl = verify_ssl
self.content_type = content_type
self.multipart_parser = multipart_parser
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
async def execute(self, client: Optional[ApiClient] = None) -> R:
owns_client = client is None
if owns_client:
client = ApiClient(
base_url=self.api_base,
auth_token=self.auth_token,
comfy_api_key=self.comfy_api_key,
timeout=self.timeout,
verify_ssl=self.verify_ssl,
max_retries=self.max_retries,
retry_delay=self.retry_delay,
retry_backoff_factor=self.retry_backoff_factor,
)
try:
request_dict: Optional[Dict[str, Any]]
if isinstance(self.request, EmptyRequest):
request_dict = None
else:
request_dict = self.request.model_dump(exclude_none=True)
for k, v in list(request_dict.items()):
if isinstance(v, Enum):
request_dict[k] = v.value
logging.debug(
f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
)
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
response_json = await client.request(
self.endpoint.method.value,
self.endpoint.path,
params=self.endpoint.query_params,
data=request_dict,
files=self.files,
content_type=self.content_type,
multipart_parser=self.multipart_parser,
)
logging.debug("=" * 50)
logging.debug("[DEBUG] RESPONSE DETAILS:")
logging.debug("[DEBUG] Status Code: 200 (Success)")
logging.debug(f"[DEBUG] Response Body: {json.dumps(response_json, indent=2)}")
logging.debug("=" * 50)
parsed_response = self.endpoint.response_model.model_validate(response_json)
logging.debug(f"[DEBUG] Parsed Response: {parsed_response}")
return parsed_response
finally:
if owns_client:
await client.close()
class TaskStatus(str, Enum):
"""Enum for task status values"""
COMPLETED = "completed"
FAILED = "failed"
PENDING = "pending"
class PollingOperation(Generic[T, R]):
"""Represents an asynchronous API operation that requires polling for completion."""
def __init__(
self,
poll_endpoint: ApiEndpoint[EmptyRequest, R],
completed_statuses: list[str],
failed_statuses: list[str],
status_extractor: Callable[[R], str],
progress_extractor: Callable[[R], float] | None = None,
result_url_extractor: Callable[[R], str] | None = None,
request: Optional[T] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str, str]] = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
max_retries: int = 3, # Max retries per individual API call
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
estimated_duration: Optional[float] = None,
node_id: Optional[str] = None,
) -> None:
self.poll_endpoint = poll_endpoint
self.request = request
self.api_base: str = api_base or args.comfy_api_base
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
if auth_kwargs is not None:
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
self.poll_interval = poll_interval
self.max_poll_attempts = max_poll_attempts
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
self.estimated_duration = estimated_duration
self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None))
self.progress_extractor = progress_extractor
self.result_url_extractor = result_url_extractor
self.node_id = node_id
self.completed_statuses = completed_statuses
self.failed_statuses = failed_statuses
self.final_response: Optional[R] = None
async def execute(self, client: Optional[ApiClient] = None) -> R:
owns_client = client is None
if owns_client:
client = ApiClient(
base_url=self.api_base,
auth_token=self.auth_token,
comfy_api_key=self.comfy_api_key,
max_retries=self.max_retries,
retry_delay=self.retry_delay,
retry_backoff_factor=self.retry_backoff_factor,
)
try:
return await self._poll_until_complete(client)
finally:
if owns_client:
await client.close()
def _display_text_on_node(self, text: str):
if not self.node_id:
return
PromptServer.instance.send_progress_text(text, self.node_id)
def _display_time_progress_on_node(self, time_completed: int | float):
if not self.node_id:
return
if self.estimated_duration is not None:
remaining = max(0, int(self.estimated_duration) - time_completed)
message = f"Task in progress: {time_completed}s (~{remaining}s remaining)"
else:
message = f"Task in progress: {time_completed}s"
self._display_text_on_node(message)
def _check_task_status(self, response: R) -> TaskStatus:
try:
status = self.status_extractor(response)
if status in self.completed_statuses:
return TaskStatus.COMPLETED
if status in self.failed_statuses:
return TaskStatus.FAILED
return TaskStatus.PENDING
except Exception as e:
logging.error("Error extracting status: %s", e)
return TaskStatus.PENDING
async def _poll_until_complete(self, client: ApiClient) -> R:
"""Poll until the task is complete"""
consecutive_errors = 0
max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors
if self.progress_extractor:
progress = utils.ProgressBar(PROGRESS_BAR_MAX)
status = TaskStatus.PENDING
for poll_count in range(1, self.max_poll_attempts + 1):
try:
logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
request_dict = (
None if self.request is None else self.request.model_dump(exclude_none=True)
)
if poll_count == 1:
logging.debug(
f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}"
)
logging.debug(
f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}"
)
# Query task status
resp = await client.request(
self.poll_endpoint.method.value,
self.poll_endpoint.path,
params=self.poll_endpoint.query_params,
data=request_dict,
)
consecutive_errors = 0 # reset on success
response_obj: R = self.poll_endpoint.response_model.model_validate(resp)
# Check if task is complete
status = self._check_task_status(response_obj)
logging.debug(f"[DEBUG] Task Status: {status}")
# If progress extractor is provided, extract progress
if self.progress_extractor:
new_progress = self.progress_extractor(response_obj)
if new_progress is not None:
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
if status == TaskStatus.COMPLETED:
message = "Task completed successfully"
if self.result_url_extractor:
result_url = self.result_url_extractor(response_obj)
if result_url:
message = f"Result URL: {result_url}"
logging.debug(f"[DEBUG] {message}")
self._display_text_on_node(message)
self.final_response = response_obj
if self.progress_extractor:
progress.update(100)
return self.final_response
if status == TaskStatus.FAILED:
message = f"Task failed: {json.dumps(resp)}"
logging.error(f"[DEBUG] {message}")
raise Exception(message)
logging.debug("[DEBUG] Task still pending, continuing to poll...")
# Task pending wait
for i in range(int(self.poll_interval)):
self._display_time_progress_on_node((poll_count - 1) * self.poll_interval + i)
await asyncio.sleep(1)
except (LocalNetworkError, ApiServerError, NetworkError) as e:
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
raise Exception(
f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
) from e
logging.warning("Network error (%s/%s): %s", consecutive_errors, max_consecutive_errors, str(e))
await asyncio.sleep(self.poll_interval)
except Exception as e:
# For other errors, increment count and potentially abort
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED:
raise Exception(
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
) from e
logging.error(f"[DEBUG] Polling error: {str(e)}")
logging.warning(
f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
f"Will retry in {self.poll_interval} seconds."
)
await asyncio.sleep(self.poll_interval)
# If we've exhausted all polling attempts
raise Exception(
f"Polling timed out after {self.max_poll_attempts} attempts (" f"{self.max_poll_attempts * self.poll_interval} seconds). "
"The operation may still be running on the server but is taking longer than expected."
)