mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-08-13 16:46:34 +00:00
* converted API nodes to async * converted BFL API nodes to async * fixed client bug; converted gemini, ideogram, minimax * fixed client bug; converted openai nodes * fixed client bug; converted moonvalley, pika nodes * fixed client bug; converted kling, luma nodes * converted pixverse, rodin nodes * converted tripo, veo2 * converted recraft nodes * add lost log_request_response call
958 lines
37 KiB
Python
958 lines
37 KiB
Python
"""
|
||
API Client Framework for api.comfy.org.
|
||
|
||
This module provides a flexible framework for making API requests from ComfyUI nodes.
|
||
It supports both synchronous and asynchronous API operations with proper type validation.
|
||
|
||
Key Components:
|
||
--------------
|
||
1. ApiClient - Handles HTTP requests with authentication and error handling
|
||
2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models
|
||
3. ApiOperation - Executes a single synchronous API operation
|
||
|
||
Usage Examples:
|
||
--------------
|
||
|
||
# Example 1: Synchronous API Operation
|
||
# ------------------------------------
|
||
# For a simple API call that returns the result immediately:
|
||
|
||
# 1. Create the API client
|
||
api_client = ApiClient(
|
||
base_url="https://api.example.com",
|
||
auth_token="your_auth_token_here",
|
||
comfy_api_key="your_comfy_api_key_here",
|
||
timeout=30.0,
|
||
verify_ssl=True
|
||
)
|
||
|
||
# 2. Define the endpoint
|
||
user_info_endpoint = ApiEndpoint(
|
||
path="/v1/users/me",
|
||
method=HttpMethod.GET,
|
||
request_model=EmptyRequest, # No request body needed
|
||
response_model=UserProfile, # Pydantic model for the response
|
||
query_params=None
|
||
)
|
||
|
||
# 3. Create the request object
|
||
request = EmptyRequest()
|
||
|
||
# 4. Create and execute the operation
|
||
operation = ApiOperation(
|
||
endpoint=user_info_endpoint,
|
||
request=request
|
||
)
|
||
user_profile = await operation.execute(client=api_client) # Returns immediately with the result
|
||
|
||
|
||
# Example 2: Asynchronous API Operation with Polling
|
||
# -------------------------------------------------
|
||
# For an API that starts a task and requires polling for completion:
|
||
|
||
# 1. Define the endpoints (initial request and polling)
|
||
generate_image_endpoint = ApiEndpoint(
|
||
path="/v1/images/generate",
|
||
method=HttpMethod.POST,
|
||
request_model=ImageGenerationRequest,
|
||
response_model=TaskCreatedResponse,
|
||
query_params=None
|
||
)
|
||
|
||
check_task_endpoint = ApiEndpoint(
|
||
path="/v1/tasks/{task_id}",
|
||
method=HttpMethod.GET,
|
||
request_model=EmptyRequest,
|
||
response_model=ImageGenerationResult,
|
||
query_params=None
|
||
)
|
||
|
||
# 2. Create the request object
|
||
request = ImageGenerationRequest(
|
||
prompt="a beautiful sunset over mountains",
|
||
width=1024,
|
||
height=1024,
|
||
num_images=1
|
||
)
|
||
|
||
# 3. Create and execute the polling operation
|
||
operation = PollingOperation(
|
||
initial_endpoint=generate_image_endpoint,
|
||
initial_request=request,
|
||
poll_endpoint=check_task_endpoint,
|
||
task_id_field="task_id",
|
||
status_field="status",
|
||
completed_statuses=["completed"],
|
||
failed_statuses=["failed", "error"]
|
||
)
|
||
|
||
# This will make the initial request and then poll until completion
|
||
result = await operation.execute(client=api_client) # Returns the final ImageGenerationResult when done
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
import aiohttp
|
||
import asyncio
|
||
import logging
|
||
import io
|
||
import socket
|
||
from aiohttp.client_exceptions import ClientError, ClientResponseError
|
||
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
|
||
from enum import Enum
|
||
import json
|
||
from urllib.parse import urljoin, urlparse
|
||
from pydantic import BaseModel, Field
|
||
import uuid # For generating unique operation IDs
|
||
|
||
from server import PromptServer
|
||
from comfy.cli_args import args
|
||
from comfy import utils
|
||
from . import request_logger
|
||
|
||
T = TypeVar("T", bound=BaseModel)
|
||
R = TypeVar("R", bound=BaseModel)
|
||
P = TypeVar("P", bound=BaseModel) # For poll response
|
||
|
||
PROGRESS_BAR_MAX = 100
|
||
|
||
|
||
class NetworkError(Exception):
|
||
"""Base exception for network-related errors with diagnostic information."""
|
||
pass
|
||
|
||
|
||
class LocalNetworkError(NetworkError):
|
||
"""Exception raised when local network connectivity issues are detected."""
|
||
pass
|
||
|
||
|
||
class ApiServerError(NetworkError):
|
||
"""Exception raised when the API server is unreachable but internet is working."""
|
||
pass
|
||
|
||
|
||
class EmptyRequest(BaseModel):
|
||
"""Base class for empty request bodies.
|
||
For GET requests, fields will be sent as query parameters."""
|
||
|
||
pass
|
||
|
||
|
||
class UploadRequest(BaseModel):
|
||
file_name: str = Field(..., description="Filename to upload")
|
||
content_type: Optional[str] = Field(
|
||
None,
|
||
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
|
||
)
|
||
|
||
|
||
class UploadResponse(BaseModel):
|
||
download_url: str = Field(..., description="URL to GET uploaded file")
|
||
upload_url: str = Field(..., description="URL to PUT file to upload")
|
||
|
||
|
||
class HttpMethod(str, Enum):
|
||
GET = "GET"
|
||
POST = "POST"
|
||
PUT = "PUT"
|
||
DELETE = "DELETE"
|
||
PATCH = "PATCH"
|
||
|
||
|
||
class ApiClient:
|
||
"""
|
||
Client for making HTTP requests to an API with authentication, error handling, and retry logic.
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
base_url: str,
|
||
auth_token: Optional[str] = None,
|
||
comfy_api_key: Optional[str] = None,
|
||
timeout: float = 3600.0,
|
||
verify_ssl: bool = True,
|
||
max_retries: int = 3,
|
||
retry_delay: float = 1.0,
|
||
retry_backoff_factor: float = 2.0,
|
||
retry_status_codes: Optional[Tuple[int, ...]] = None,
|
||
session: Optional[aiohttp.ClientSession] = None,
|
||
):
|
||
self.base_url = base_url
|
||
self.auth_token = auth_token
|
||
self.comfy_api_key = comfy_api_key
|
||
self.timeout = timeout
|
||
self.verify_ssl = verify_ssl
|
||
self.max_retries = max_retries
|
||
self.retry_delay = retry_delay
|
||
self.retry_backoff_factor = retry_backoff_factor
|
||
# Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests),
|
||
# 500, 502, 503, 504 (Server Errors)
|
||
self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504)
|
||
self._session: Optional[aiohttp.ClientSession] = session
|
||
self._owns_session = session is None # Track if we have to close it
|
||
|
||
@staticmethod
|
||
def _generate_operation_id(path: str) -> str:
|
||
"""Generates a unique operation ID for logging."""
|
||
return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}"
|
||
|
||
@staticmethod
|
||
def _create_json_payload_args(
|
||
data: Optional[Dict[str, Any]] = None,
|
||
headers: Optional[Dict[str, str]] = None,
|
||
) -> Dict[str, Any]:
|
||
return {
|
||
"json": data,
|
||
"headers": headers,
|
||
}
|
||
|
||
def _create_form_data_args(
|
||
self,
|
||
data: Dict[str, Any] | None,
|
||
files: Dict[str, Any] | None,
|
||
headers: Optional[Dict[str, str]] = None,
|
||
multipart_parser: Callable | None = None,
|
||
) -> Dict[str, Any]:
|
||
if headers and "Content-Type" in headers:
|
||
del headers["Content-Type"]
|
||
|
||
if multipart_parser and data:
|
||
data = multipart_parser(data)
|
||
|
||
form = aiohttp.FormData(default_to_multipart=True)
|
||
if data: # regular text fields
|
||
for k, v in data.items():
|
||
if v is None:
|
||
continue # aiohttp fails to serialize "None" values
|
||
# aiohttp expects strings or bytes; convert enums etc.
|
||
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
|
||
|
||
if files:
|
||
file_iter = files if isinstance(files, list) else files.items()
|
||
for field_name, file_obj in file_iter:
|
||
if file_obj is None:
|
||
continue # aiohttp fails to serialize "None" values
|
||
# file_obj can be (filename, bytes/io.BytesIO, content_type) tuple
|
||
if isinstance(file_obj, tuple):
|
||
filename, file_value, content_type = self._unpack_tuple(file_obj)
|
||
else:
|
||
file_value = file_obj
|
||
filename = getattr(file_obj, "name", field_name)
|
||
content_type = "application/octet-stream"
|
||
|
||
form.add_field(
|
||
name=field_name,
|
||
value=file_value,
|
||
filename=filename,
|
||
content_type=content_type,
|
||
)
|
||
return {"data": form, "headers": headers or {}}
|
||
|
||
@staticmethod
|
||
def _create_urlencoded_form_data_args(
|
||
data: Dict[str, Any],
|
||
headers: Optional[Dict[str, str]] = None,
|
||
) -> Dict[str, Any]:
|
||
headers = headers or {}
|
||
headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||
return {
|
||
"data": data,
|
||
"headers": headers,
|
||
}
|
||
|
||
def get_headers(self) -> Dict[str, str]:
|
||
"""Get headers for API requests, including authentication if available"""
|
||
headers = {"Content-Type": "application/json", "Accept": "application/json"}
|
||
|
||
if self.auth_token:
|
||
headers["Authorization"] = f"Bearer {self.auth_token}"
|
||
elif self.comfy_api_key:
|
||
headers["X-API-KEY"] = self.comfy_api_key
|
||
|
||
return headers
|
||
|
||
async def _check_connectivity(self, target_url: str) -> Dict[str, bool]:
|
||
"""
|
||
Check connectivity to determine if network issues are local or server-related.
|
||
|
||
Args:
|
||
target_url: URL to check connectivity to
|
||
|
||
Returns:
|
||
Dictionary with connectivity status details
|
||
"""
|
||
results = {
|
||
"internet_accessible": False,
|
||
"api_accessible": False,
|
||
"is_local_issue": False,
|
||
"is_api_issue": False,
|
||
}
|
||
timeout = aiohttp.ClientTimeout(total=5.0)
|
||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||
try:
|
||
async with session.get("https://www.google.com", ssl=self.verify_ssl) as resp:
|
||
results["internet_accessible"] = resp.status < 500
|
||
except (ClientError, asyncio.TimeoutError, socket.gaierror):
|
||
results["is_local_issue"] = True
|
||
return results # cannot reach the internet – early exit
|
||
|
||
# Now check API health endpoint
|
||
parsed = urlparse(target_url)
|
||
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
|
||
try:
|
||
async with session.get(health_url, ssl=self.verify_ssl) as resp:
|
||
results["api_accessible"] = resp.status < 500
|
||
except ClientError:
|
||
pass # leave as False
|
||
|
||
results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"]
|
||
return results
|
||
|
||
async def request(
|
||
self,
|
||
method: str,
|
||
path: str,
|
||
params: Optional[Dict[str, Any]] = None,
|
||
data: Optional[Dict[str, Any]] = None,
|
||
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
|
||
headers: Optional[Dict[str, str]] = None,
|
||
content_type: str = "application/json",
|
||
multipart_parser: Callable | None = None,
|
||
retry_count: int = 0, # Used internally for tracking retries
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
Make an HTTP request to the API with automatic retries for transient errors.
|
||
|
||
Args:
|
||
method: HTTP method (GET, POST, etc.)
|
||
path: API endpoint path (will be joined with base_url)
|
||
params: Query parameters
|
||
data: body data
|
||
files: Files to upload
|
||
headers: Additional headers
|
||
content_type: Content type of the request. Defaults to application/json.
|
||
retry_count: Internal parameter for tracking retries, do not set manually
|
||
|
||
Returns:
|
||
Parsed JSON response
|
||
|
||
Raises:
|
||
LocalNetworkError: If local network connectivity issues are detected
|
||
ApiServerError: If the API server is unreachable but internet is working
|
||
Exception: For other request failures
|
||
"""
|
||
|
||
# Build full URL and merge headers
|
||
relative_path = path.lstrip("/")
|
||
url = urljoin(self.base_url, relative_path)
|
||
self._check_auth(self.auth_token, self.comfy_api_key)
|
||
|
||
request_headers = self.get_headers()
|
||
if headers:
|
||
request_headers.update(headers)
|
||
if files:
|
||
request_headers.pop("Content-Type", None)
|
||
if params:
|
||
params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values
|
||
|
||
logging.debug(f"[DEBUG] Request Headers: {request_headers}")
|
||
logging.debug(f"[DEBUG] Files: {files}")
|
||
logging.debug(f"[DEBUG] Params: {params}")
|
||
logging.debug(f"[DEBUG] Data: {data}")
|
||
|
||
if content_type == "application/x-www-form-urlencoded":
|
||
payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
|
||
elif content_type == "multipart/form-data":
|
||
payload_args = self._create_form_data_args(data, files, request_headers, multipart_parser)
|
||
else:
|
||
payload_args = self._create_json_payload_args(data, request_headers)
|
||
|
||
operation_id = self._generate_operation_id(path)
|
||
request_logger.log_request_response(
|
||
operation_id=operation_id,
|
||
request_method=method,
|
||
request_url=url,
|
||
request_headers=request_headers,
|
||
request_params=params,
|
||
request_data=data if content_type == "application/json" else "[form-data or other]",
|
||
)
|
||
|
||
session = await self._get_session()
|
||
try:
|
||
async with session.request(
|
||
method,
|
||
url,
|
||
params=params,
|
||
ssl=self.verify_ssl,
|
||
**payload_args,
|
||
) as resp:
|
||
if resp.status >= 400:
|
||
try:
|
||
error_data = await resp.json()
|
||
except (aiohttp.ContentTypeError, json.JSONDecodeError):
|
||
error_data = await resp.text()
|
||
|
||
return await self._handle_http_error(
|
||
ClientResponseError(resp.request_info, resp.history, status=resp.status, message=error_data),
|
||
operation_id,
|
||
method,
|
||
url,
|
||
params,
|
||
data,
|
||
files,
|
||
headers,
|
||
content_type,
|
||
multipart_parser,
|
||
retry_count=retry_count,
|
||
response_content=error_data,
|
||
)
|
||
|
||
# Success – parse JSON (safely) and log
|
||
try:
|
||
payload = await resp.json()
|
||
response_content_to_log = payload
|
||
except (aiohttp.ContentTypeError, json.JSONDecodeError):
|
||
payload = {}
|
||
response_content_to_log = await resp.text()
|
||
|
||
request_logger.log_request_response(
|
||
operation_id=operation_id,
|
||
request_method=method,
|
||
request_url=url,
|
||
response_status_code=resp.status,
|
||
response_headers=dict(resp.headers),
|
||
response_content=response_content_to_log,
|
||
)
|
||
return payload
|
||
|
||
except (ClientError, asyncio.TimeoutError, socket.gaierror) as e:
|
||
# Treat as *connection* problem – optionally retry, else escalate
|
||
if retry_count < self.max_retries:
|
||
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
|
||
logging.warning("Connection error. Retrying in %.2fs (%s/%s): %s", delay, retry_count + 1,
|
||
self.max_retries, str(e))
|
||
await asyncio.sleep(delay)
|
||
return await self.request(
|
||
method,
|
||
path,
|
||
params=params,
|
||
data=data,
|
||
files=files,
|
||
headers=headers,
|
||
content_type=content_type,
|
||
multipart_parser=multipart_parser,
|
||
retry_count=retry_count + 1,
|
||
)
|
||
# One final connectivity check for diagnostics
|
||
connectivity = await self._check_connectivity(self.base_url)
|
||
if connectivity["is_local_issue"]:
|
||
raise LocalNetworkError(
|
||
"Unable to connect to the API server due to local network issues. "
|
||
"Please check your internet connection and try again."
|
||
) from e
|
||
raise ApiServerError(
|
||
f"The API server at {self.base_url} is currently unreachable. "
|
||
f"The service may be experiencing issues. Please try again later."
|
||
) from e
|
||
|
||
@staticmethod
|
||
def _check_auth(auth_token, comfy_api_key):
|
||
"""Verify that an auth token is present or comfy_api_key is present"""
|
||
if auth_token is None and comfy_api_key is None:
|
||
raise Exception("Unauthorized: Please login first to use this node.")
|
||
return auth_token or comfy_api_key
|
||
|
||
@staticmethod
|
||
async def upload_file(
|
||
upload_url: str,
|
||
file: io.BytesIO | str,
|
||
content_type: str | None = None,
|
||
max_retries: int = 3,
|
||
retry_delay: float = 1.0,
|
||
retry_backoff_factor: float = 2.0,
|
||
) -> aiohttp.ClientResponse:
|
||
"""Upload a file to the API with retry logic.
|
||
|
||
Args:
|
||
upload_url: The URL to upload to
|
||
file: Either a file path string, BytesIO object, or tuple of (file_path, filename)
|
||
content_type: Optional mime type to set for the upload
|
||
max_retries: Maximum number of retry attempts
|
||
retry_delay: Initial delay between retries in seconds
|
||
retry_backoff_factor: Multiplier for the delay after each retry
|
||
"""
|
||
headers: Dict[str, str] = {}
|
||
skip_auto_headers: set[str] = set()
|
||
if content_type:
|
||
headers["Content-Type"] = content_type
|
||
else:
|
||
# tell aiohttp not to add Content-Type that will break the request signature and result in a 403 status.
|
||
skip_auto_headers.add("Content-Type")
|
||
|
||
# Extract file bytes
|
||
if isinstance(file, io.BytesIO):
|
||
file.seek(0)
|
||
data = file.read()
|
||
elif isinstance(file, str):
|
||
with open(file, "rb") as f:
|
||
data = f.read()
|
||
else:
|
||
raise ValueError("File must be BytesIO or str path")
|
||
|
||
operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}"
|
||
request_logger.log_request_response(
|
||
operation_id=operation_id,
|
||
request_method="PUT",
|
||
request_url=upload_url,
|
||
request_headers=headers,
|
||
request_data=f"[File data {len(data)} bytes]",
|
||
)
|
||
|
||
delay = retry_delay
|
||
for attempt in range(max_retries + 1):
|
||
try:
|
||
timeout = aiohttp.ClientTimeout(total=None) # honour server side timeouts
|
||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||
async with session.put(
|
||
upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers,
|
||
) as resp:
|
||
resp.raise_for_status()
|
||
request_logger.log_request_response(
|
||
operation_id=operation_id,
|
||
request_method="PUT",
|
||
request_url=upload_url,
|
||
response_status_code=resp.status,
|
||
response_headers=dict(resp.headers),
|
||
response_content="File uploaded successfully.",
|
||
)
|
||
return resp
|
||
except (ClientError, asyncio.TimeoutError) as e:
|
||
request_logger.log_request_response(
|
||
operation_id=operation_id,
|
||
request_method="PUT",
|
||
request_url=upload_url,
|
||
response_status_code=e.status if hasattr(e, "status") else None,
|
||
response_headers=dict(e.headers) if getattr(e, "headers") else None,
|
||
response_content=None,
|
||
error_message=f"{type(e).__name__}: {str(e)}",
|
||
)
|
||
if attempt < max_retries:
|
||
logging.warning(
|
||
"Upload failed (%s/%s). Retrying in %.2fs. %s", attempt + 1, max_retries, delay, str(e)
|
||
)
|
||
await asyncio.sleep(delay)
|
||
delay *= retry_backoff_factor
|
||
else:
|
||
raise NetworkError(f"Failed to upload file after {max_retries + 1} attempts: {e}") from e
|
||
|
||
async def _handle_http_error(
|
||
self,
|
||
exc: ClientResponseError,
|
||
operation_id: str,
|
||
*req_meta,
|
||
retry_count: int,
|
||
response_content: dict | str = "",
|
||
) -> Dict[str, Any]:
|
||
status_code = exc.status
|
||
if status_code == 401:
|
||
user_friendly = "Unauthorized: Please login first to use this node."
|
||
elif status_code == 402:
|
||
user_friendly = "Payment Required: Please add credits to your account to use this node."
|
||
elif status_code == 409:
|
||
user_friendly = "There is a problem with your account. Please contact support@comfy.org."
|
||
elif status_code == 429:
|
||
user_friendly = "Rate Limit Exceeded: Please try again later."
|
||
else:
|
||
if isinstance(response_content, dict):
|
||
if "error" in response_content and "message" in response_content["error"]:
|
||
user_friendly = f"API Error: {response_content['error']['message']}"
|
||
if "type" in response_content["error"]:
|
||
user_friendly += f" (Type: {response_content['error']['type']})"
|
||
else: # Handle cases where error is just a JSON dict with unknown format
|
||
user_friendly = f"API Error: {json.dumps(response_content)}"
|
||
else:
|
||
if len(response_content) < 200: # Arbitrary limit for display
|
||
user_friendly = f"API Error (raw): {response_content}"
|
||
else:
|
||
user_friendly = f"API Error (raw, status {response_content})"
|
||
|
||
request_logger.log_request_response(
|
||
operation_id=operation_id,
|
||
request_method=req_meta[0],
|
||
request_url=req_meta[1],
|
||
response_status_code=exc.status,
|
||
response_headers=dict(req_meta[5]) if req_meta[5] else None,
|
||
response_content=response_content,
|
||
error_message=f"HTTP Error {exc.status}",
|
||
)
|
||
|
||
logging.debug(f"[DEBUG] API Error: {user_friendly} (Status: {status_code})")
|
||
if response_content:
|
||
logging.debug(f"[DEBUG] Response content: {response_content}")
|
||
|
||
# Retry if eligible
|
||
if status_code in self.retry_status_codes and retry_count < self.max_retries:
|
||
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
|
||
logging.warning(
|
||
"HTTP error %s. Retrying in %.2fs (%s/%s)",
|
||
status_code,
|
||
delay,
|
||
retry_count + 1,
|
||
self.max_retries,
|
||
)
|
||
await asyncio.sleep(delay)
|
||
return await self.request(
|
||
req_meta[0], # method
|
||
req_meta[1].replace(self.base_url, ""), # path
|
||
params=req_meta[2],
|
||
data=req_meta[3],
|
||
files=req_meta[4],
|
||
headers=req_meta[5],
|
||
content_type=req_meta[6],
|
||
multipart_parser=req_meta[7],
|
||
retry_count=retry_count + 1,
|
||
)
|
||
|
||
raise Exception(user_friendly) from exc
|
||
|
||
@staticmethod
|
||
def _unpack_tuple(t):
|
||
"""Helper to normalise (filename, file, content_type) tuples."""
|
||
if len(t) == 3:
|
||
return t
|
||
elif len(t) == 2:
|
||
return t[0], t[1], "application/octet-stream"
|
||
else:
|
||
raise ValueError("files tuple must be (filename, file[, content_type])")
|
||
|
||
async def _get_session(self) -> aiohttp.ClientSession:
|
||
if self._session is None or self._session.closed:
|
||
timeout = aiohttp.ClientTimeout(total=self.timeout)
|
||
self._session = aiohttp.ClientSession(timeout=timeout)
|
||
self._owns_session = True
|
||
return self._session
|
||
|
||
async def close(self) -> None:
|
||
if self._owns_session and self._session and not self._session.closed:
|
||
await self._session.close()
|
||
|
||
async def __aenter__(self) -> "ApiClient":
|
||
"""Allow usage as async‑context‑manager – ensures clean teardown"""
|
||
return self
|
||
|
||
async def __aexit__(self, exc_type, exc, tb):
|
||
await self.close()
|
||
|
||
|
||
class ApiEndpoint(Generic[T, R]):
|
||
"""Defines an API endpoint with its request and response types"""
|
||
|
||
def __init__(
|
||
self,
|
||
path: str,
|
||
method: HttpMethod,
|
||
request_model: Type[T],
|
||
response_model: Type[R],
|
||
query_params: Optional[Dict[str, Any]] = None,
|
||
):
|
||
"""Initialize an API endpoint definition.
|
||
|
||
Args:
|
||
path: The URL path for this endpoint, can include placeholders like {id}
|
||
method: The HTTP method to use (GET, POST, etc.)
|
||
request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint
|
||
response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint
|
||
query_params: Optional dictionary of query parameters to include in the request
|
||
"""
|
||
self.path = path
|
||
self.method = method
|
||
self.request_model = request_model
|
||
self.response_model = response_model
|
||
self.query_params = query_params or {}
|
||
|
||
|
||
class SynchronousOperation(Generic[T, R]):
|
||
"""Represents a single synchronous API operation."""
|
||
|
||
def __init__(
|
||
self,
|
||
endpoint: ApiEndpoint[T, R],
|
||
request: T,
|
||
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
|
||
api_base: str | None = None,
|
||
auth_token: Optional[str] = None,
|
||
comfy_api_key: Optional[str] = None,
|
||
auth_kwargs: Optional[Dict[str, str]] = None,
|
||
timeout: float = 604800.0,
|
||
verify_ssl: bool = True,
|
||
content_type: str = "application/json",
|
||
multipart_parser: Callable | None = None,
|
||
max_retries: int = 3,
|
||
retry_delay: float = 1.0,
|
||
retry_backoff_factor: float = 2.0,
|
||
) -> None:
|
||
self.endpoint = endpoint
|
||
self.request = request
|
||
self.files = files
|
||
self.api_base: str = api_base or args.comfy_api_base
|
||
self.auth_token = auth_token
|
||
self.comfy_api_key = comfy_api_key
|
||
if auth_kwargs is not None:
|
||
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
|
||
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
|
||
self.timeout = timeout
|
||
self.verify_ssl = verify_ssl
|
||
self.content_type = content_type
|
||
self.multipart_parser = multipart_parser
|
||
self.max_retries = max_retries
|
||
self.retry_delay = retry_delay
|
||
self.retry_backoff_factor = retry_backoff_factor
|
||
|
||
async def execute(self, client: Optional[ApiClient] = None) -> R:
|
||
owns_client = client is None
|
||
if owns_client:
|
||
client = ApiClient(
|
||
base_url=self.api_base,
|
||
auth_token=self.auth_token,
|
||
comfy_api_key=self.comfy_api_key,
|
||
timeout=self.timeout,
|
||
verify_ssl=self.verify_ssl,
|
||
max_retries=self.max_retries,
|
||
retry_delay=self.retry_delay,
|
||
retry_backoff_factor=self.retry_backoff_factor,
|
||
)
|
||
|
||
try:
|
||
request_dict: Optional[Dict[str, Any]]
|
||
if isinstance(self.request, EmptyRequest):
|
||
request_dict = None
|
||
else:
|
||
request_dict = self.request.model_dump(exclude_none=True)
|
||
for k, v in list(request_dict.items()):
|
||
if isinstance(v, Enum):
|
||
request_dict[k] = v.value
|
||
|
||
logging.debug(
|
||
f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
|
||
)
|
||
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
|
||
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
|
||
|
||
response_json = await client.request(
|
||
self.endpoint.method.value,
|
||
self.endpoint.path,
|
||
params=self.endpoint.query_params,
|
||
data=request_dict,
|
||
files=self.files,
|
||
content_type=self.content_type,
|
||
multipart_parser=self.multipart_parser,
|
||
)
|
||
|
||
logging.debug("=" * 50)
|
||
logging.debug("[DEBUG] RESPONSE DETAILS:")
|
||
logging.debug("[DEBUG] Status Code: 200 (Success)")
|
||
logging.debug(f"[DEBUG] Response Body: {json.dumps(response_json, indent=2)}")
|
||
logging.debug("=" * 50)
|
||
|
||
parsed_response = self.endpoint.response_model.model_validate(response_json)
|
||
logging.debug(f"[DEBUG] Parsed Response: {parsed_response}")
|
||
return parsed_response
|
||
finally:
|
||
if owns_client:
|
||
await client.close()
|
||
|
||
|
||
class TaskStatus(str, Enum):
|
||
"""Enum for task status values"""
|
||
|
||
COMPLETED = "completed"
|
||
FAILED = "failed"
|
||
PENDING = "pending"
|
||
|
||
|
||
class PollingOperation(Generic[T, R]):
|
||
"""Represents an asynchronous API operation that requires polling for completion."""
|
||
|
||
def __init__(
|
||
self,
|
||
poll_endpoint: ApiEndpoint[EmptyRequest, R],
|
||
completed_statuses: list[str],
|
||
failed_statuses: list[str],
|
||
status_extractor: Callable[[R], str],
|
||
progress_extractor: Callable[[R], float] | None = None,
|
||
result_url_extractor: Callable[[R], str] | None = None,
|
||
request: Optional[T] = None,
|
||
api_base: str | None = None,
|
||
auth_token: Optional[str] = None,
|
||
comfy_api_key: Optional[str] = None,
|
||
auth_kwargs: Optional[Dict[str, str]] = None,
|
||
poll_interval: float = 5.0,
|
||
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
|
||
max_retries: int = 3, # Max retries per individual API call
|
||
retry_delay: float = 1.0,
|
||
retry_backoff_factor: float = 2.0,
|
||
estimated_duration: Optional[float] = None,
|
||
node_id: Optional[str] = None,
|
||
) -> None:
|
||
self.poll_endpoint = poll_endpoint
|
||
self.request = request
|
||
self.api_base: str = api_base or args.comfy_api_base
|
||
self.auth_token = auth_token
|
||
self.comfy_api_key = comfy_api_key
|
||
if auth_kwargs is not None:
|
||
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
|
||
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
|
||
self.poll_interval = poll_interval
|
||
self.max_poll_attempts = max_poll_attempts
|
||
self.max_retries = max_retries
|
||
self.retry_delay = retry_delay
|
||
self.retry_backoff_factor = retry_backoff_factor
|
||
self.estimated_duration = estimated_duration
|
||
self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None))
|
||
self.progress_extractor = progress_extractor
|
||
self.result_url_extractor = result_url_extractor
|
||
self.node_id = node_id
|
||
self.completed_statuses = completed_statuses
|
||
self.failed_statuses = failed_statuses
|
||
self.final_response: Optional[R] = None
|
||
|
||
async def execute(self, client: Optional[ApiClient] = None) -> R:
|
||
owns_client = client is None
|
||
if owns_client:
|
||
client = ApiClient(
|
||
base_url=self.api_base,
|
||
auth_token=self.auth_token,
|
||
comfy_api_key=self.comfy_api_key,
|
||
max_retries=self.max_retries,
|
||
retry_delay=self.retry_delay,
|
||
retry_backoff_factor=self.retry_backoff_factor,
|
||
)
|
||
try:
|
||
return await self._poll_until_complete(client)
|
||
finally:
|
||
if owns_client:
|
||
await client.close()
|
||
|
||
def _display_text_on_node(self, text: str):
|
||
if not self.node_id:
|
||
return
|
||
PromptServer.instance.send_progress_text(text, self.node_id)
|
||
|
||
def _display_time_progress_on_node(self, time_completed: int | float):
|
||
if not self.node_id:
|
||
return
|
||
if self.estimated_duration is not None:
|
||
remaining = max(0, int(self.estimated_duration) - time_completed)
|
||
message = f"Task in progress: {time_completed}s (~{remaining}s remaining)"
|
||
else:
|
||
message = f"Task in progress: {time_completed}s"
|
||
self._display_text_on_node(message)
|
||
|
||
def _check_task_status(self, response: R) -> TaskStatus:
|
||
try:
|
||
status = self.status_extractor(response)
|
||
if status in self.completed_statuses:
|
||
return TaskStatus.COMPLETED
|
||
if status in self.failed_statuses:
|
||
return TaskStatus.FAILED
|
||
return TaskStatus.PENDING
|
||
except Exception as e:
|
||
logging.error("Error extracting status: %s", e)
|
||
return TaskStatus.PENDING
|
||
|
||
async def _poll_until_complete(self, client: ApiClient) -> R:
|
||
"""Poll until the task is complete"""
|
||
consecutive_errors = 0
|
||
max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors
|
||
|
||
if self.progress_extractor:
|
||
progress = utils.ProgressBar(PROGRESS_BAR_MAX)
|
||
|
||
status = TaskStatus.PENDING
|
||
for poll_count in range(1, self.max_poll_attempts + 1):
|
||
try:
|
||
logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
|
||
|
||
request_dict = (
|
||
None if self.request is None else self.request.model_dump(exclude_none=True)
|
||
)
|
||
|
||
if poll_count == 1:
|
||
logging.debug(
|
||
f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}"
|
||
)
|
||
logging.debug(
|
||
f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}"
|
||
)
|
||
|
||
# Query task status
|
||
resp = await client.request(
|
||
self.poll_endpoint.method.value,
|
||
self.poll_endpoint.path,
|
||
params=self.poll_endpoint.query_params,
|
||
data=request_dict,
|
||
)
|
||
consecutive_errors = 0 # reset on success
|
||
response_obj: R = self.poll_endpoint.response_model.model_validate(resp)
|
||
|
||
# Check if task is complete
|
||
status = self._check_task_status(response_obj)
|
||
logging.debug(f"[DEBUG] Task Status: {status}")
|
||
|
||
# If progress extractor is provided, extract progress
|
||
if self.progress_extractor:
|
||
new_progress = self.progress_extractor(response_obj)
|
||
if new_progress is not None:
|
||
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
|
||
|
||
if status == TaskStatus.COMPLETED:
|
||
message = "Task completed successfully"
|
||
if self.result_url_extractor:
|
||
result_url = self.result_url_extractor(response_obj)
|
||
if result_url:
|
||
message = f"Result URL: {result_url}"
|
||
logging.debug(f"[DEBUG] {message}")
|
||
self._display_text_on_node(message)
|
||
self.final_response = response_obj
|
||
if self.progress_extractor:
|
||
progress.update(100)
|
||
return self.final_response
|
||
if status == TaskStatus.FAILED:
|
||
message = f"Task failed: {json.dumps(resp)}"
|
||
logging.error(f"[DEBUG] {message}")
|
||
raise Exception(message)
|
||
logging.debug("[DEBUG] Task still pending, continuing to poll...")
|
||
# Task pending – wait
|
||
for i in range(int(self.poll_interval)):
|
||
self._display_time_progress_on_node((poll_count - 1) * self.poll_interval + i)
|
||
await asyncio.sleep(1)
|
||
|
||
except (LocalNetworkError, ApiServerError, NetworkError) as e:
|
||
consecutive_errors += 1
|
||
if consecutive_errors >= max_consecutive_errors:
|
||
raise Exception(
|
||
f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
|
||
) from e
|
||
logging.warning("Network error (%s/%s): %s", consecutive_errors, max_consecutive_errors, str(e))
|
||
await asyncio.sleep(self.poll_interval)
|
||
except Exception as e:
|
||
# For other errors, increment count and potentially abort
|
||
consecutive_errors += 1
|
||
if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED:
|
||
raise Exception(
|
||
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
|
||
) from e
|
||
|
||
logging.error(f"[DEBUG] Polling error: {str(e)}")
|
||
logging.warning(
|
||
f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
|
||
f"Will retry in {self.poll_interval} seconds."
|
||
)
|
||
await asyncio.sleep(self.poll_interval)
|
||
|
||
# If we've exhausted all polling attempts
|
||
raise Exception(
|
||
f"Polling timed out after {self.max_poll_attempts} attempts (" f"{self.max_poll_attempts * self.poll_interval} seconds). "
|
||
"The operation may still be running on the server but is taking longer than expected."
|
||
)
|