ComfyAPI Core v0.0.2

This commit is contained in:
Jacob Segal 2025-07-16 15:24:26 -07:00
parent 4831e9c2c4
commit 780c3ead16
33 changed files with 1871 additions and 446 deletions

View File

@ -153,6 +153,7 @@ parser.add_argument("--disable-metadata", action="store_true", help="Disable sav
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
parser.add_argument("--generate-api-stubs", action="store_true", help="Generate .pyi stub files for API sync wrappers and exit.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")

View File

@ -0,0 +1,86 @@
#!/usr/bin/env python3
"""
Script to generate .pyi stub files for the synchronous API wrappers.
This allows generating stubs without running the full ComfyUI application.
"""
import os
import sys
import logging
import importlib
# Add ComfyUI to path so we can import modules
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from comfy_api.internal.async_to_sync import AsyncToSyncConverter
from comfy_api.version_list import supported_versions
def generate_stubs_for_module(module_name: str) -> None:
"""Generate stub files for a specific module that exports ComfyAPI and ComfyAPISync."""
try:
# Import the module
module = importlib.import_module(module_name)
# Check if module has ComfyAPISync (the sync wrapper)
if hasattr(module, "ComfyAPISync"):
# Module already has a sync class
api_class = getattr(module, "ComfyAPI", None)
sync_class = getattr(module, "ComfyAPISync")
if api_class:
# Generate the stub file
AsyncToSyncConverter.generate_stub_file(api_class, sync_class)
logging.info(f"Generated stub file for {module_name}")
else:
logging.warning(
f"Module {module_name} has ComfyAPISync but no ComfyAPI"
)
elif hasattr(module, "ComfyAPI"):
# Module only has async API, need to create sync wrapper first
from comfy_api.internal.async_to_sync import create_sync_class
api_class = getattr(module, "ComfyAPI")
sync_class = create_sync_class(api_class)
# Generate the stub file
AsyncToSyncConverter.generate_stub_file(api_class, sync_class)
logging.info(f"Generated stub file for {module_name}")
else:
logging.warning(
f"Module {module_name} does not export ComfyAPI or ComfyAPISync"
)
except Exception as e:
logging.error(f"Failed to generate stub for {module_name}: {e}")
import traceback
traceback.print_exc()
def main():
"""Main function to generate all API stub files."""
logging.basicConfig(level=logging.INFO)
logging.info("Starting stub generation...")
# Dynamically get module names from supported_versions
api_modules = []
for api_class in supported_versions:
# Extract module name from the class
module_name = api_class.__module__
if module_name not in api_modules:
api_modules.append(module_name)
logging.info(f"Found {len(api_modules)} API modules: {api_modules}")
# Generate stubs for each module
for module_name in api_modules:
generate_stubs_for_module(module_name)
logging.info("Stub generation complete!")
if __name__ == "__main__":
main()

View File

@ -1,8 +1,2 @@
from .basic_types import ImageInput, AudioInput
from .video_types import VideoInput
__all__ = [
"ImageInput",
"AudioInput",
"VideoInput",
]
# This file only exists for backwards compatibility.
from comfy_api.latest.input import * # noqa: F403

View File

@ -1,20 +1,2 @@
import torch
from typing import TypedDict
ImageInput = torch.Tensor
"""
An image in format [B, H, W, C] where B is the batch size, C is the number of channels,
"""
class AudioInput(TypedDict):
"""
TypedDict representing audio input.
"""
waveform: torch.Tensor
"""
Tensor in the format [B, C, T] where B is the batch size, C is the number of channels,
"""
sample_rate: int
# This file only exists for backwards compatibility.
from comfy_api.latest.input.basic_types import * # noqa: F403

View File

@ -1,72 +1,2 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, Union
import io
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
class VideoInput(ABC):
"""
Abstract base class for video input types.
"""
@abstractmethod
def get_components(self) -> VideoComponents:
"""
Abstract method to get the video components (images, audio, and frame rate).
Returns:
VideoComponents containing images, audio, and frame rate
"""
pass
@abstractmethod
def save_to(
self,
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
"""
Abstract method to save the video input to a file.
"""
pass
def get_stream_source(self) -> Union[str, io.BytesIO]:
"""
Get a streamable source for the video. This allows processing without
loading the entire video into memory.
Returns:
Either a file path (str) or a BytesIO object that can be opened with av.
Default implementation creates a BytesIO buffer, but subclasses should
override this for better performance when possible.
"""
buffer = io.BytesIO()
self.save_to(buffer)
buffer.seek(0)
return buffer
# Provide a default implementation, but subclasses can provide optimized versions
# if possible.
def get_dimensions(self) -> tuple[int, int]:
"""
Returns the dimensions of the video input.
Returns:
Tuple of (width, height)
"""
components = self.get_components()
return components.images.shape[2], components.images.shape[1]
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
Returns:
Duration in seconds
"""
components = self.get_components()
frame_count = components.images.shape[0]
return float(frame_count / components.frame_rate)
# This file only exists for backwards compatibility.
from comfy_api.latest.input.video_types import * # noqa: F403

View File

@ -1,7 +1,2 @@
from .video_types import VideoFromFile, VideoFromComponents
__all__ = [
# Implementations
"VideoFromFile",
"VideoFromComponents",
]
# This file only exists for backwards compatibility.
from comfy_api.latest.input_impl import * # noqa: F403

View File

@ -1,312 +1,2 @@
from __future__ import annotations
from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
from comfy_api.input import AudioInput
import av
import io
import json
import numpy as np
import torch
from comfy_api.input import VideoInput
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
def container_to_output_format(container_format: str | None) -> str | None:
"""
A container's `format` may be a comma-separated list of formats.
E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
However, writing to a file/stream with `av.open` requires a single format,
or `None` to auto-detect.
"""
if not container_format:
return None # Auto-detect
if "," not in container_format:
return container_format
formats = container_format.split(",")
return formats[0]
def get_open_write_kwargs(
dest: str | io.BytesIO, container_format: str, to_format: str | None
) -> dict:
"""Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
open_kwargs = {
"mode": "w",
# If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
"options": {"movflags": "use_metadata_tags"},
}
is_write_to_buffer = isinstance(dest, io.BytesIO)
if is_write_to_buffer:
# Set output format explicitly, since it cannot be inferred from file extension
if to_format == VideoContainer.AUTO:
to_format = container_format.lower()
elif isinstance(to_format, str):
to_format = to_format.lower()
open_kwargs["format"] = container_to_output_format(to_format)
return open_kwargs
class VideoFromFile(VideoInput):
"""
Class representing video input from a file.
"""
def __init__(self, file: str | io.BytesIO):
"""
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
containing the file contents.
"""
self.__file = file
def get_stream_source(self) -> str | io.BytesIO:
"""
Return the underlying file source for efficient streaming.
This avoids unnecessary memory copies when the source is already a file path.
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
return self.__file
def get_dimensions(self) -> tuple[int, int]:
"""
Returns the dimensions of the video input.
Returns:
Tuple of (width, height)
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
for stream in container.streams:
if stream.type == 'video':
assert isinstance(stream, av.VideoStream)
return stream.width, stream.height
raise ValueError(f"No video stream found in file '{self.__file}'")
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
Returns:
Duration in seconds
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
if container.duration is not None:
return float(container.duration / av.time_base)
# Fallback: calculate from frame count and frame rate
video_stream = next(
(s for s in container.streams if s.type == "video"), None
)
if video_stream and video_stream.frames and video_stream.average_rate:
return float(video_stream.frames / video_stream.average_rate)
# Last resort: decode frames to count them
if video_stream and video_stream.average_rate:
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count > 0:
return float(frame_count / video_stream.average_rate)
raise ValueError(f"Could not determine duration for file '{self.__file}'")
def get_components_internal(self, container: InputContainer) -> VideoComponents:
# Get video frames
frames = []
for frame in container.decode(video=0):
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
frames.append(img)
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
# Get frame rate
video_stream = next(s for s in container.streams if s.type == 'video')
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
# Get audio if available
audio = None
try:
container.seek(0) # Reset the container to the beginning
for stream in container.streams:
if stream.type != 'audio':
continue
assert isinstance(stream, av.AudioStream)
audio_frames = []
for packet in container.demux(stream):
for frame in packet.decode():
assert isinstance(frame, av.AudioFrame)
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
audio = AudioInput({
"waveform": audio_tensor,
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
})
except StopIteration:
pass # No audio stream
metadata = container.metadata
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
def get_components(self) -> VideoComponents:
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
return self.get_components_internal(container)
raise ValueError(f"No video stream found in file '{self.__file}'")
def save_to(
self,
path: str | io.BytesIO,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
container_format = container.format.name
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
reuse_streams = True
if format != VideoContainer.AUTO and format not in container_format.split(","):
reuse_streams = False
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
reuse_streams = False
if not reuse_streams:
components = self.get_components_internal(container)
video = VideoFromComponents(components)
return video.save_to(
path,
format=format,
codec=codec,
metadata=metadata
)
streams = container.streams
open_kwargs = get_open_write_kwargs(path, container_format, format)
with av.open(path, **open_kwargs) as output_container:
# Copy over the original metadata
for key, value in container.metadata.items():
if metadata is None or key not in metadata:
output_container.metadata[key] = value
# Add our new metadata
if metadata is not None:
for key, value in metadata.items():
if isinstance(value, str):
output_container.metadata[key] = value
else:
output_container.metadata[key] = json.dumps(value)
# Add streams to the new container
stream_map = {}
for stream in streams:
if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
stream_map[stream] = out_stream
# Write packets to the new container
for packet in container.demux():
if packet.stream in stream_map and packet.dts is not None:
packet.stream = stream_map[packet.stream]
output_container.mux(packet)
class VideoFromComponents(VideoInput):
"""
Class representing video input from tensors.
"""
def __init__(self, components: VideoComponents):
self.__components = components
def get_components(self) -> VideoComponents:
return VideoComponents(
images=self.__components.images,
audio=self.__components.audio,
frame_rate=self.__components.frame_rate
)
def save_to(
self,
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
raise ValueError("Only H264 codec is supported for now")
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
# Add metadata before writing any streams
if metadata is not None:
for key, value in metadata.items():
output.metadata[key] = json.dumps(value)
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
# Create a video stream
video_stream = output.add_stream('h264', rate=frame_rate)
video_stream.width = self.__components.images.shape[2]
video_stream.height = self.__components.images.shape[1]
video_stream.pix_fmt = 'yuv420p'
# Create an audio stream
audio_sample_rate = 1
audio_stream: Optional[av.AudioStream] = None
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
audio_stream.sample_rate = audio_sample_rate
audio_stream.format = 'fltp'
# Encode video
for i, frame in enumerate(self.__components.images):
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
packet = video_stream.encode(frame)
output.mux(packet)
# Flush video
packet = video_stream.encode(None)
output.mux(packet)
if audio_stream and self.__components.audio:
# Encode audio
samples_per_frame = int(audio_sample_rate / frame_rate)
num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
for i in range(num_frames):
start = i * samples_per_frame
end = start + samples_per_frame
# TODO(Feature) - Add support for stereo audio
chunk = (
self.__components.audio["waveform"][0, 0, start:end]
.unsqueeze(0)
.contiguous()
.numpy()
)
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
audio_frame.sample_rate = audio_sample_rate
audio_frame.pts = i * samples_per_frame
for packet in audio_stream.encode(audio_frame):
output.mux(packet)
# Flush audio
for packet in audio_stream.encode(None):
output.mux(packet)
# This file only exists for backwards compatibility.
from comfy_api.latest.input_impl.video_types import * # noqa: F403

View File

@ -0,0 +1,7 @@
# Internal infrastructure for ComfyAPI
from .api_registry import (
ComfyAPIBase as ComfyAPIBase,
ComfyAPIWithVersion as ComfyAPIWithVersion,
register_versions as register_versions,
get_all_versions as get_all_versions,
)

View File

@ -0,0 +1,39 @@
from typing import Type, List, NamedTuple
from comfy_api.internal.singleton import ProxiedSingleton
from packaging import version as packaging_version
class ComfyAPIBase(ProxiedSingleton):
def __init__(self):
pass
class ComfyAPIWithVersion(NamedTuple):
version: str
api_class: Type[ComfyAPIBase]
def parse_version(version_str: str) -> packaging_version.Version:
"""
Parses a version string into a packaging_version.Version object.
Raises ValueError if the version string is invalid.
"""
if version_str == "latest":
return packaging_version.parse("9999999.9999999.9999999")
return packaging_version.parse(version_str)
registered_versions: List[ComfyAPIWithVersion] = []
def register_versions(versions: List[ComfyAPIWithVersion]):
versions.sort(key=lambda x: parse_version(x.version))
global registered_versions
registered_versions = versions
def get_all_versions() -> List[ComfyAPIWithVersion]:
"""
Returns a list of all registered ComfyAPI versions.
"""
return registered_versions

View File

@ -0,0 +1,942 @@
import asyncio
import concurrent.futures
import contextvars
import functools
import inspect
import logging
import os
import textwrap
import threading
from enum import Enum
from typing import Optional, Type, get_origin, get_args
class TypeTracker:
"""Tracks types discovered during stub generation for automatic import generation."""
def __init__(self):
self.discovered_types = {} # type_name -> (module, qualname)
self.builtin_types = {
"Any",
"Dict",
"List",
"Optional",
"Tuple",
"Union",
"Set",
"Sequence",
"cast",
"NamedTuple",
"str",
"int",
"float",
"bool",
"None",
"bytes",
"object",
"type",
"dict",
"list",
"tuple",
"set",
}
self.already_imported = (
set()
) # Track types already imported to avoid duplicates
def track_type(self, annotation):
"""Track a type annotation and record its module/import info."""
if annotation is None or annotation is type(None):
return
# Skip builtins and typing module types we already import
type_name = getattr(annotation, "__name__", None)
if type_name and (
type_name in self.builtin_types or type_name in self.already_imported
):
return
# Get module and qualname
module = getattr(annotation, "__module__", None)
qualname = getattr(annotation, "__qualname__", type_name or "")
# Skip types from typing module (they're already imported)
if module == "typing":
return
# Skip UnionType and GenericAlias from types module as they're handled specially
if module == "types" and type_name in ("UnionType", "GenericAlias"):
return
if module and module not in ["builtins", "__main__"]:
# Store the type info
if type_name:
self.discovered_types[type_name] = (module, qualname)
def get_imports(self, main_module_name: str) -> list[str]:
"""Generate import statements for all discovered types."""
imports = []
imports_by_module = {}
for type_name, (module, qualname) in sorted(self.discovered_types.items()):
# Skip types from the main module (they're already imported)
if main_module_name and module == main_module_name:
continue
if module not in imports_by_module:
imports_by_module[module] = []
if type_name not in imports_by_module[module]: # Avoid duplicates
imports_by_module[module].append(type_name)
# Generate import statements
for module, types in sorted(imports_by_module.items()):
if len(types) == 1:
imports.append(f"from {module} import {types[0]}")
else:
imports.append(f"from {module} import {', '.join(sorted(set(types)))}")
return imports
class AsyncToSyncConverter:
"""
Provides utilities to convert async classes to sync classes with proper type hints.
"""
_thread_pool: Optional[concurrent.futures.ThreadPoolExecutor] = None
_thread_pool_lock = threading.Lock()
_thread_pool_initialized = False
@classmethod
def get_thread_pool(cls, max_workers=None) -> concurrent.futures.ThreadPoolExecutor:
"""Get or create the shared thread pool with proper thread-safe initialization."""
# Fast path - check if already initialized without acquiring lock
if cls._thread_pool_initialized:
assert cls._thread_pool is not None, "Thread pool should be initialized"
return cls._thread_pool
# Slow path - acquire lock and create pool if needed
with cls._thread_pool_lock:
if not cls._thread_pool_initialized:
cls._thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers, thread_name_prefix="async_to_sync_"
)
cls._thread_pool_initialized = True
# This should never be None at this point, but add assertion for type checker
assert cls._thread_pool is not None
return cls._thread_pool
@classmethod
def run_async_in_thread(cls, coro_func, *args, **kwargs):
"""
Run an async function in a separate thread from the thread pool.
Blocks until the async function completes.
Properly propagates contextvars between threads and manages event loops.
"""
# Capture current context - this includes all context variables
context = contextvars.copy_context()
# Store the result and any exception that occurs
result_container: dict = {"result": None, "exception": None}
# Function that runs in the thread pool
def run_in_thread():
# Create new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# Create the coroutine within the context
async def run_with_context():
# The coroutine function might access context variables
return await coro_func(*args, **kwargs)
# Run the coroutine with the captured context
# This ensures all context variables are available in the async function
result = context.run(loop.run_until_complete, run_with_context())
result_container["result"] = result
except Exception as e:
# Store the exception to re-raise in the calling thread
result_container["exception"] = e
finally:
# Ensure event loop is properly closed to prevent warnings
try:
# Cancel any remaining tasks
pending = asyncio.all_tasks(loop)
for task in pending:
task.cancel()
# Run the loop briefly to handle cancellations
if pending:
loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
except Exception:
pass # Ignore errors during cleanup
# Close the event loop
loop.close()
# Clear the event loop from the thread
asyncio.set_event_loop(None)
# Submit to thread pool and wait for result
thread_pool = cls.get_thread_pool()
future = thread_pool.submit(run_in_thread)
future.result() # Wait for completion
# Re-raise any exception that occurred in the thread
if result_container["exception"] is not None:
raise result_container["exception"]
return result_container["result"]
@classmethod
def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
"""
Creates a new class with synchronous versions of all async methods.
Args:
async_class: The async class to convert
thread_pool_size: Size of thread pool to use
Returns:
A new class with sync versions of all async methods
"""
sync_class_name = "ComfyAPISyncStub"
cls.get_thread_pool(thread_pool_size)
# Create a proper class with docstrings and proper base classes
sync_class_dict = {
"__doc__": async_class.__doc__,
"__module__": async_class.__module__,
"__qualname__": sync_class_name,
"__orig_class__": async_class, # Store original class for typing references
}
# Create __init__ method
def __init__(self, *args, **kwargs):
self._async_instance = async_class(*args, **kwargs)
# Handle annotated class attributes (like execution: Execution)
# Get all annotations from the class hierarchy
all_annotations = {}
for base_class in reversed(inspect.getmro(async_class)):
if hasattr(base_class, "__annotations__"):
all_annotations.update(base_class.__annotations__)
# For each annotated attribute, check if it needs to be created or wrapped
for attr_name, attr_type in all_annotations.items():
if hasattr(self._async_instance, attr_name):
# Attribute exists on the instance
attr = getattr(self._async_instance, attr_name)
# Check if this attribute needs a sync wrapper
if hasattr(attr, "__class__"):
from comfy_api.internal.singleton import ProxiedSingleton
if isinstance(attr, ProxiedSingleton):
# Create a sync version of this attribute
try:
sync_attr_class = cls.create_sync_class(attr.__class__)
# Create instance of the sync wrapper with the async instance
sync_attr = object.__new__(sync_attr_class) # type: ignore
sync_attr._async_instance = attr
setattr(self, attr_name, sync_attr)
except Exception:
# If we can't create a sync version, keep the original
setattr(self, attr_name, attr)
else:
# Not async, just copy the reference
setattr(self, attr_name, attr)
else:
# Attribute doesn't exist, but is annotated - create it
# This handles cases like execution: Execution
if isinstance(attr_type, type):
# Check if the type is defined as an inner class
if hasattr(async_class, attr_type.__name__):
inner_class = getattr(async_class, attr_type.__name__)
from comfy_api.internal.singleton import ProxiedSingleton
# Create an instance of the inner class
try:
# For ProxiedSingleton classes, get or create the singleton instance
if issubclass(inner_class, ProxiedSingleton):
async_instance = inner_class.get_instance()
else:
async_instance = inner_class()
# Create sync wrapper
sync_attr_class = cls.create_sync_class(inner_class)
sync_attr = object.__new__(sync_attr_class) # type: ignore
sync_attr._async_instance = async_instance
setattr(self, attr_name, sync_attr)
# Also set on the async instance for consistency
setattr(self._async_instance, attr_name, async_instance)
except Exception as e:
logging.warning(
f"Failed to create instance for {attr_name}: {e}"
)
# Handle other instance attributes that might not be annotated
for name, attr in inspect.getmembers(self._async_instance):
if name.startswith("_") or hasattr(self, name):
continue
# If attribute is an instance of a class, and that class is defined in the original class
# we need to check if it needs a sync wrapper
if isinstance(attr, object) and not isinstance(
attr, (str, int, float, bool, list, dict, tuple)
):
from comfy_api.internal.singleton import ProxiedSingleton
if isinstance(attr, ProxiedSingleton):
# Create a sync version of this nested class
try:
sync_attr_class = cls.create_sync_class(attr.__class__)
# Create instance of the sync wrapper with the async instance
sync_attr = object.__new__(sync_attr_class) # type: ignore
sync_attr._async_instance = attr
setattr(self, name, sync_attr)
except Exception:
# If we can't create a sync version, keep the original
setattr(self, name, attr)
sync_class_dict["__init__"] = __init__
# Process methods from the async class
for name, method in inspect.getmembers(
async_class, predicate=inspect.isfunction
):
if name.startswith("_"):
continue
# Extract the actual return type from a coroutine
if inspect.iscoroutinefunction(method):
# Create sync version of async method with proper signature
@functools.wraps(method)
def sync_method(self, *args, _method_name=name, **kwargs):
async_method = getattr(self._async_instance, _method_name)
return AsyncToSyncConverter.run_async_in_thread(
async_method, *args, **kwargs
)
# Add to the class dict
sync_class_dict[name] = sync_method
else:
# For regular methods, create a proxy method
@functools.wraps(method)
def proxy_method(self, *args, _method_name=name, **kwargs):
method = getattr(self._async_instance, _method_name)
return method(*args, **kwargs)
# Add to the class dict
sync_class_dict[name] = proxy_method
# Handle property access
for name, prop in inspect.getmembers(
async_class, lambda x: isinstance(x, property)
):
def make_property(name, prop_obj):
def getter(self):
value = getattr(self._async_instance, name)
if inspect.iscoroutinefunction(value):
def sync_fn(*args, **kwargs):
return AsyncToSyncConverter.run_async_in_thread(
value, *args, **kwargs
)
return sync_fn
return value
def setter(self, value):
setattr(self._async_instance, name, value)
return property(getter, setter if prop_obj.fset else None)
sync_class_dict[name] = make_property(name, prop)
# Create the class
sync_class = type(sync_class_name, (object,), sync_class_dict)
return sync_class
@classmethod
def _format_type_annotation(
cls, annotation, type_tracker: Optional[TypeTracker] = None
) -> str:
"""Convert a type annotation to its string representation for stub files."""
if (
annotation is inspect.Parameter.empty
or annotation is inspect.Signature.empty
):
return "Any"
# Handle None type
if annotation is type(None):
return "None"
# Track the type if we have a tracker
if type_tracker:
type_tracker.track_type(annotation)
# Try using typing.get_origin/get_args for Python 3.8+
try:
origin = get_origin(annotation)
args = get_args(annotation)
if origin is not None:
# Track the origin type
if type_tracker:
type_tracker.track_type(origin)
# Get the origin name
origin_name = getattr(origin, "__name__", str(origin))
if "." in origin_name:
origin_name = origin_name.split(".")[-1]
# Special handling for types.UnionType (Python 3.10+ pipe operator)
if origin_name == "UnionType":
origin_name = "Union"
# Format arguments recursively
if args:
formatted_args = [
cls._format_type_annotation(arg, type_tracker) for arg in args
]
return f"{origin_name}[{', '.join(formatted_args)}]"
else:
return origin_name
except (AttributeError, TypeError):
# Fallback for older Python versions or non-generic types
pass
# Handle generic types the old way for compatibility
if hasattr(annotation, "__origin__") and hasattr(annotation, "__args__"):
origin = annotation.__origin__
origin_name = (
origin.__name__
if hasattr(origin, "__name__")
else str(origin).split("'")[1]
)
# Format each type argument
args = []
for arg in annotation.__args__:
args.append(cls._format_type_annotation(arg, type_tracker))
return f"{origin_name}[{', '.join(args)}]"
# Handle regular types with __name__
if hasattr(annotation, "__name__"):
return annotation.__name__
# Handle special module types (like types from typing module)
if hasattr(annotation, "__module__") and hasattr(annotation, "__qualname__"):
# For types like typing.Literal, typing.TypedDict, etc.
return annotation.__qualname__
# Last resort: string conversion with cleanup
type_str = str(annotation)
# Clean up common patterns more robustly
if type_str.startswith("<class '") and type_str.endswith("'>"):
type_str = type_str[8:-2] # Remove "<class '" and "'>"
# Remove module prefixes for common modules
for prefix in ["typing.", "builtins.", "types."]:
if type_str.startswith(prefix):
type_str = type_str[len(prefix) :]
# Handle special cases
if type_str in ("_empty", "inspect._empty"):
return "None"
# Fix NoneType (this should rarely be needed now)
if type_str == "NoneType":
return "None"
return type_str
@classmethod
def _extract_coroutine_return_type(cls, annotation):
"""Extract the actual return type from a Coroutine annotation."""
if hasattr(annotation, "__args__") and len(annotation.__args__) > 2:
# Coroutine[Any, Any, ReturnType] -> extract ReturnType
return annotation.__args__[2]
return annotation
@classmethod
def _format_parameter_default(cls, default_value) -> str:
"""Format a parameter's default value for stub files."""
if default_value is inspect.Parameter.empty:
return ""
elif default_value is None:
return " = None"
elif isinstance(default_value, bool):
return f" = {default_value}"
elif default_value == {}:
return " = {}"
elif default_value == []:
return " = []"
else:
return f" = {default_value}"
@classmethod
def _format_method_parameters(
cls,
sig: inspect.Signature,
skip_self: bool = True,
type_tracker: Optional[TypeTracker] = None,
) -> str:
"""Format method parameters for stub files."""
params = []
for i, (param_name, param) in enumerate(sig.parameters.items()):
if i == 0 and param_name == "self" and skip_self:
params.append("self")
else:
# Get type annotation
type_str = cls._format_type_annotation(param.annotation, type_tracker)
# Get default value
default_str = cls._format_parameter_default(param.default)
# Combine parameter parts
if param.annotation is inspect.Parameter.empty:
params.append(f"{param_name}: Any{default_str}")
else:
params.append(f"{param_name}: {type_str}{default_str}")
return ", ".join(params)
@classmethod
def _generate_method_signature(
cls,
method_name: str,
method,
is_async: bool = False,
type_tracker: Optional[TypeTracker] = None,
) -> str:
"""Generate a complete method signature for stub files."""
sig = inspect.signature(method)
# For async methods, extract the actual return type
return_annotation = sig.return_annotation
if is_async and inspect.iscoroutinefunction(method):
return_annotation = cls._extract_coroutine_return_type(return_annotation)
# Format parameters
params_str = cls._format_method_parameters(sig, type_tracker=type_tracker)
# Format return type
return_type = cls._format_type_annotation(return_annotation, type_tracker)
if return_annotation is inspect.Signature.empty:
return_type = "None"
return f"def {method_name}({params_str}) -> {return_type}: ..."
@classmethod
def _generate_imports(
cls, async_class: Type, type_tracker: TypeTracker
) -> list[str]:
"""Generate import statements for the stub file."""
imports = []
# Add standard typing imports
imports.append(
"from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple"
)
# Add imports from the original module
if async_class.__module__ != "builtins":
module = inspect.getmodule(async_class)
additional_types = []
if module:
for name, obj in sorted(inspect.getmembers(module)):
if isinstance(obj, type):
# Check for NamedTuple
if issubclass(obj, tuple) and hasattr(obj, "_fields"):
additional_types.append(name)
# Mark as already imported
type_tracker.already_imported.add(name)
# Check for Enum
elif issubclass(obj, Enum) and name != "Enum":
additional_types.append(name)
# Mark as already imported
type_tracker.already_imported.add(name)
if additional_types:
type_imports = ", ".join([async_class.__name__] + additional_types)
imports.append(f"from {async_class.__module__} import {type_imports}")
else:
imports.append(
f"from {async_class.__module__} import {async_class.__name__}"
)
# Add imports for all discovered types
# Pass the main module name to avoid duplicate imports
imports.extend(
type_tracker.get_imports(main_module_name=async_class.__module__)
)
# Add base module import if needed
if hasattr(inspect.getmodule(async_class), "__name__"):
module_name = inspect.getmodule(async_class).__name__
if "." in module_name:
base_module = module_name.split(".")[0]
# Only add if not already importing from it
if not any(imp.startswith(f"from {base_module}") for imp in imports):
imports.append(f"import {base_module}")
return imports
@classmethod
def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
"""Extract class attributes that are classes themselves."""
class_attributes = []
# Look for class attributes that are classes
for name, attr in sorted(inspect.getmembers(async_class)):
if isinstance(attr, type) and not name.startswith("_"):
class_attributes.append((name, attr))
elif (
hasattr(async_class, "__annotations__")
and name in async_class.__annotations__
):
annotation = async_class.__annotations__[name]
if isinstance(annotation, type):
class_attributes.append((name, annotation))
return class_attributes
@classmethod
def _generate_inner_class_stub(
cls,
name: str,
attr: Type,
indent: str = " ",
type_tracker: Optional[TypeTracker] = None,
) -> list[str]:
"""Generate stub for an inner class."""
stub_lines = []
stub_lines.append(f"{indent}class {name}Sync:")
# Add docstring if available
if hasattr(attr, "__doc__") and attr.__doc__:
stub_lines.extend(
cls._format_docstring_for_stub(attr.__doc__, f"{indent} ")
)
# Add __init__ if it exists
if hasattr(attr, "__init__"):
try:
init_method = getattr(attr, "__init__")
init_sig = inspect.signature(init_method)
# Format parameters
params_str = cls._format_method_parameters(
init_sig, type_tracker=type_tracker
)
# Add __init__ docstring if available (before the method)
if hasattr(init_method, "__doc__") and init_method.__doc__:
stub_lines.extend(
cls._format_docstring_for_stub(
init_method.__doc__, f"{indent} "
)
)
stub_lines.append(
f"{indent} def __init__({params_str}) -> None: ..."
)
except (ValueError, TypeError):
stub_lines.append(
f"{indent} def __init__(self, *args, **kwargs) -> None: ..."
)
# Add methods to the inner class
has_methods = False
for method_name, method in sorted(
inspect.getmembers(attr, predicate=inspect.isfunction)
):
if method_name.startswith("_"):
continue
has_methods = True
try:
# Add method docstring if available (before the method signature)
if method.__doc__:
stub_lines.extend(
cls._format_docstring_for_stub(method.__doc__, f"{indent} ")
)
method_sig = cls._generate_method_signature(
method_name, method, is_async=True, type_tracker=type_tracker
)
stub_lines.append(f"{indent} {method_sig}")
except (ValueError, TypeError):
stub_lines.append(
f"{indent} def {method_name}(self, *args, **kwargs): ..."
)
if not has_methods:
stub_lines.append(f"{indent} pass")
return stub_lines
@classmethod
def _format_docstring_for_stub(
cls, docstring: str, indent: str = " "
) -> list[str]:
"""Format a docstring for inclusion in a stub file with proper indentation."""
if not docstring:
return []
# First, dedent the docstring to remove any existing indentation
dedented = textwrap.dedent(docstring).strip()
# Split into lines
lines = dedented.split("\n")
# Build the properly indented docstring
result = []
result.append(f'{indent}"""')
for line in lines:
if line.strip(): # Non-empty line
result.append(f"{indent}{line}")
else: # Empty line
result.append("")
result.append(f'{indent}"""')
return result
@classmethod
def _post_process_stub_content(cls, stub_content: list[str]) -> list[str]:
"""Post-process stub content to fix any remaining issues."""
processed = []
for line in stub_content:
# Skip processing imports
if line.startswith(("from ", "import ")):
processed.append(line)
continue
# Fix method signatures missing return types
if (
line.strip().startswith("def ")
and line.strip().endswith(": ...")
and ") -> " not in line
):
# Add -> None for methods without return annotation
line = line.replace(": ...", " -> None: ...")
processed.append(line)
return processed
@classmethod
def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
"""
Generate a .pyi stub file for the sync class to help IDEs with type checking.
"""
try:
# Only generate stub if we can determine module path
if async_class.__module__ == "__main__":
return
module = inspect.getmodule(async_class)
if not module:
return
module_path = module.__file__
if not module_path:
return
# Create stub file path in a 'generated' subdirectory
module_dir = os.path.dirname(module_path)
stub_dir = os.path.join(module_dir, "generated")
# Ensure the generated directory exists
os.makedirs(stub_dir, exist_ok=True)
module_name = os.path.basename(module_path)
if module_name.endswith(".py"):
module_name = module_name[:-3]
sync_stub_path = os.path.join(stub_dir, f"{sync_class.__name__}.pyi")
# Create a type tracker for this stub generation
type_tracker = TypeTracker()
stub_content = []
# We'll generate imports after processing all methods to capture all types
# Leave a placeholder for imports
imports_placeholder_index = len(stub_content)
stub_content.append("") # Will be replaced with imports later
# Class definition
stub_content.append(f"class {sync_class.__name__}:")
# Docstring
if async_class.__doc__:
stub_content.extend(
cls._format_docstring_for_stub(async_class.__doc__, " ")
)
# Generate __init__
try:
init_method = async_class.__init__
init_signature = inspect.signature(init_method)
# Format parameters
params_str = cls._format_method_parameters(
init_signature, type_tracker=type_tracker
)
# Add __init__ docstring if available (before the method)
if hasattr(init_method, "__doc__") and init_method.__doc__:
stub_content.extend(
cls._format_docstring_for_stub(init_method.__doc__, " ")
)
stub_content.append(f" def __init__({params_str}) -> None: ...")
except (ValueError, TypeError):
stub_content.append(
" def __init__(self, *args, **kwargs) -> None: ..."
)
stub_content.append("") # Add newline after __init__
# Get class attributes
class_attributes = cls._get_class_attributes(async_class)
# Generate inner classes
for name, attr in class_attributes:
inner_class_stub = cls._generate_inner_class_stub(
name, attr, type_tracker=type_tracker
)
stub_content.extend(inner_class_stub)
stub_content.append("") # Add newline after the inner class
# Add methods to the main class
processed_methods = set() # Keep track of methods we've processed
for name, method in sorted(
inspect.getmembers(async_class, predicate=inspect.isfunction)
):
if name.startswith("_") or name in processed_methods:
continue
processed_methods.add(name)
try:
method_sig = cls._generate_method_signature(
name, method, is_async=True, type_tracker=type_tracker
)
# Add docstring if available (before the method signature for proper formatting)
if method.__doc__:
stub_content.extend(
cls._format_docstring_for_stub(method.__doc__, " ")
)
stub_content.append(f" {method_sig}")
stub_content.append("") # Add newline after each method
except (ValueError, TypeError):
# If we can't get the signature, just add a simple stub
stub_content.append(f" def {name}(self, *args, **kwargs): ...")
stub_content.append("") # Add newline
# Add properties
for name, prop in sorted(
inspect.getmembers(async_class, lambda x: isinstance(x, property))
):
stub_content.append(" @property")
stub_content.append(f" def {name}(self) -> Any: ...")
if prop.fset:
stub_content.append(f" @{name}.setter")
stub_content.append(
f" def {name}(self, value: Any) -> None: ..."
)
stub_content.append("") # Add newline after each property
# Add placeholders for the nested class instances
# Check the actual attribute names from class annotations and attributes
attribute_mappings = {}
# First check annotations for typed attributes (including from parent classes)
# Collect all annotations from the class hierarchy
all_annotations = {}
for base_class in reversed(inspect.getmro(async_class)):
if hasattr(base_class, "__annotations__"):
all_annotations.update(base_class.__annotations__)
for attr_name, attr_type in sorted(all_annotations.items()):
for class_name, class_type in class_attributes:
# If the class type matches the annotated type
if attr_type == class_type or (
hasattr(attr_type, "__name__")
and attr_type.__name__ == class_name
):
attribute_mappings[class_name] = attr_name
# Remove the extra checking - annotations should be sufficient
# Add the attribute declarations with proper names
for class_name, _ in class_attributes:
# Use the attribute name if found in mappings, otherwise use class name
attr_name = attribute_mappings.get(class_name, class_name)
stub_content.append(f" {attr_name}: {class_name}Sync")
stub_content.append("") # Add a final newline
# Now generate imports with all discovered types
imports = cls._generate_imports(async_class, type_tracker)
# Deduplicate imports while preserving order
seen = set()
unique_imports = []
for imp in imports:
if imp not in seen:
seen.add(imp)
unique_imports.append(imp)
else:
logging.warning(f"Duplicate import detected: {imp}")
# Replace the placeholder with actual imports
stub_content[imports_placeholder_index : imports_placeholder_index + 1] = (
unique_imports
)
# Post-process stub content
stub_content = cls._post_process_stub_content(stub_content)
# Write stub file
with open(sync_stub_path, "w") as f:
f.write("\n".join(stub_content))
logging.info(f"Generated stub file: {sync_stub_path}")
except Exception as e:
# If stub generation fails, log the error but don't break the main functionality
logging.error(
f"Error generating stub file for {sync_class.__name__}: {str(e)}"
)
import traceback
logging.error(traceback.format_exc())
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
"""
Creates a sync version of an async class
Args:
async_class: The async class to convert
thread_pool_size: Size of thread pool to use
Returns:
A new class with sync versions of all async methods
"""
return AsyncToSyncConverter.create_sync_class(async_class, thread_pool_size)

View File

@ -0,0 +1,33 @@
from typing import Type, TypeVar
class SingletonMetaclass(type):
T = TypeVar("T", bound="SingletonMetaclass")
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(SingletonMetaclass, cls).__call__(
*args, **kwargs
)
return cls._instances[cls]
def inject_instance(cls: Type[T], instance: T) -> None:
assert cls not in SingletonMetaclass._instances, (
"Cannot inject instance after first instantiation"
)
SingletonMetaclass._instances[cls] = instance
def get_instance(cls: Type[T], *args, **kwargs) -> T:
"""
Gets the singleton instance of the class, creating it if it doesn't exist.
"""
if cls not in SingletonMetaclass._instances:
SingletonMetaclass._instances[cls] = super(
SingletonMetaclass, cls
).__call__(*args, **kwargs)
return cls._instances[cls]
class ProxiedSingleton(object, metaclass=SingletonMetaclass):
def __init__(self):
super().__init__()

View File

@ -0,0 +1,76 @@
from typing import Type, TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest.input import ImageInput
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple
from PIL import Image
from comfy.cli_args import args
import numpy as np
class ComfyAPI_latest(ComfyAPIBase):
VERSION = "latest"
STABLE = False
class Execution(ProxiedSingleton):
async def set_progress(
self,
value: float,
max_value: float,
node_id: str | None = None,
preview_image: PreviewImageTuple | Image.Image | ImageInput | None = None,
) -> None:
"""
Update the progress bar displayed in the ComfyUI interface.
This function allows custom nodes and API calls to report their progress
back to the user interface, providing visual feedback during long operations.
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
"""
executing_context = get_executing_context()
if node_id is None and executing_context is not None:
node_id = executing_context.node_id
if node_id is None:
raise ValueError("node_id must be provided if not in executing context")
# Convert preview_image to PreviewImageTuple if needed
if preview_image is not None:
# First convert to PIL Image if needed
if isinstance(preview_image, ImageInput):
# Convert ImageInput (torch.Tensor) to PIL Image
# Handle tensor shape [B, H, W, C] -> get first image if batch
tensor = preview_image
if len(tensor.shape) == 4:
tensor = tensor[0]
# Convert to numpy array and scale to 0-255
image_np = (tensor.cpu().numpy() * 255).astype(np.uint8)
preview_image = Image.fromarray(image_np)
if isinstance(preview_image, Image.Image):
# Detect image format from PIL Image
image_format = preview_image.format if preview_image.format else "JPEG"
preview_image = (image_format, preview_image, args.preview_size)
get_progress_state().update_progress(
node_id=node_id,
value=value,
max_value=max_value,
image=preview_image,
)
execution: Execution
ComfyAPI = ComfyAPI_latest
# Create a synchronous version of the API
if TYPE_CHECKING:
import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPI_latest)

View File

@ -0,0 +1,20 @@
from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple
from comfy_api.latest import ComfyAPI_latest
from PIL.Image import Image
from torch import Tensor
class ComfyAPISyncStub:
def __init__(self) -> None: ...
class ExecutionSync:
def __init__(self) -> None: ...
"""
Update the progress bar displayed in the ComfyUI interface.
This function allows custom nodes and API calls to report their progress
back to the user interface, providing visual feedback during long operations.
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
"""
def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[tuple[str, Image, Union[int, None]], Image, Tensor, None] = None) -> None: ...
execution: ExecutionSync

View File

@ -0,0 +1,10 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .video_types import VideoInput
__all__ = [
"ImageInput",
"AudioInput",
"VideoInput",
"MaskInput",
"LatentInput",
]

View File

@ -0,0 +1,42 @@
import torch
from typing import TypedDict, List, Optional
ImageInput = torch.Tensor
"""
An image in format [B, H, W, C] where B is the batch size, C is the number of channels,
"""
MaskInput = torch.Tensor
"""
A mask in format [B, H, W] where B is the batch size
"""
class AudioInput(TypedDict):
"""
TypedDict representing audio input.
"""
waveform: torch.Tensor
"""
Tensor in the format [B, C, T] where B is the batch size, C is the number of channels,
"""
sample_rate: int
class LatentInput(TypedDict):
"""
TypedDict representing latent input.
"""
samples: torch.Tensor
"""
Tensor in the format [B, C, H, W] where B is the batch size, C is the number of channels,
H is the height, and W is the width.
"""
noise_mask: Optional[MaskInput]
"""
Optional noise mask tensor in the same format as samples.
"""
batch_index: Optional[List[int]]

View File

@ -0,0 +1,72 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, Union
import io
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
class VideoInput(ABC):
"""
Abstract base class for video input types.
"""
@abstractmethod
def get_components(self) -> VideoComponents:
"""
Abstract method to get the video components (images, audio, and frame rate).
Returns:
VideoComponents containing images, audio, and frame rate
"""
pass
@abstractmethod
def save_to(
self,
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
"""
Abstract method to save the video input to a file.
"""
pass
def get_stream_source(self) -> Union[str, io.BytesIO]:
"""
Get a streamable source for the video. This allows processing without
loading the entire video into memory.
Returns:
Either a file path (str) or a BytesIO object that can be opened with av.
Default implementation creates a BytesIO buffer, but subclasses should
override this for better performance when possible.
"""
buffer = io.BytesIO()
self.save_to(buffer)
buffer.seek(0)
return buffer
# Provide a default implementation, but subclasses can provide optimized versions
# if possible.
def get_dimensions(self) -> tuple[int, int]:
"""
Returns the dimensions of the video input.
Returns:
Tuple of (width, height)
"""
components = self.get_components()
return components.images.shape[2], components.images.shape[1]
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
Returns:
Duration in seconds
"""
components = self.get_components()
frame_count = components.images.shape[0]
return float(frame_count / components.frame_rate)

View File

@ -0,0 +1,7 @@
from .video_types import VideoFromFile, VideoFromComponents
__all__ = [
# Implementations
"VideoFromFile",
"VideoFromComponents",
]

View File

@ -0,0 +1,312 @@
from __future__ import annotations
from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
from comfy_api.latest.input import AudioInput, VideoInput
import av
import io
import json
import numpy as np
import torch
from comfy_api.latest.util import VideoContainer, VideoCodec, VideoComponents
def container_to_output_format(container_format: str | None) -> str | None:
"""
A container's `format` may be a comma-separated list of formats.
E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
However, writing to a file/stream with `av.open` requires a single format,
or `None` to auto-detect.
"""
if not container_format:
return None # Auto-detect
if "," not in container_format:
return container_format
formats = container_format.split(",")
return formats[0]
def get_open_write_kwargs(
dest: str | io.BytesIO, container_format: str, to_format: str | None
) -> dict:
"""Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
open_kwargs = {
"mode": "w",
# If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
"options": {"movflags": "use_metadata_tags"},
}
is_write_to_buffer = isinstance(dest, io.BytesIO)
if is_write_to_buffer:
# Set output format explicitly, since it cannot be inferred from file extension
if to_format == VideoContainer.AUTO:
to_format = container_format.lower()
elif isinstance(to_format, str):
to_format = to_format.lower()
open_kwargs["format"] = container_to_output_format(to_format)
return open_kwargs
class VideoFromFile(VideoInput):
"""
Class representing video input from a file.
"""
def __init__(self, file: str | io.BytesIO):
"""
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
containing the file contents.
"""
self.__file = file
def get_stream_source(self) -> str | io.BytesIO:
"""
Return the underlying file source for efficient streaming.
This avoids unnecessary memory copies when the source is already a file path.
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
return self.__file
def get_dimensions(self) -> tuple[int, int]:
"""
Returns the dimensions of the video input.
Returns:
Tuple of (width, height)
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
for stream in container.streams:
if stream.type == 'video':
assert isinstance(stream, av.VideoStream)
return stream.width, stream.height
raise ValueError(f"No video stream found in file '{self.__file}'")
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
Returns:
Duration in seconds
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
if container.duration is not None:
return float(container.duration / av.time_base)
# Fallback: calculate from frame count and frame rate
video_stream = next(
(s for s in container.streams if s.type == "video"), None
)
if video_stream and video_stream.frames and video_stream.average_rate:
return float(video_stream.frames / video_stream.average_rate)
# Last resort: decode frames to count them
if video_stream and video_stream.average_rate:
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count > 0:
return float(frame_count / video_stream.average_rate)
raise ValueError(f"Could not determine duration for file '{self.__file}'")
def get_components_internal(self, container: InputContainer) -> VideoComponents:
# Get video frames
frames = []
for frame in container.decode(video=0):
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
frames.append(img)
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
# Get frame rate
video_stream = next(s for s in container.streams if s.type == 'video')
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
# Get audio if available
audio = None
try:
container.seek(0) # Reset the container to the beginning
for stream in container.streams:
if stream.type != 'audio':
continue
assert isinstance(stream, av.AudioStream)
audio_frames = []
for packet in container.demux(stream):
for frame in packet.decode():
assert isinstance(frame, av.AudioFrame)
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
audio = AudioInput({
"waveform": audio_tensor,
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
})
except StopIteration:
pass # No audio stream
metadata = container.metadata
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
def get_components(self) -> VideoComponents:
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
return self.get_components_internal(container)
raise ValueError(f"No video stream found in file '{self.__file}'")
def save_to(
self,
path: str | io.BytesIO,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
container_format = container.format.name
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
reuse_streams = True
if format != VideoContainer.AUTO and format not in container_format.split(","):
reuse_streams = False
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
reuse_streams = False
if not reuse_streams:
components = self.get_components_internal(container)
video = VideoFromComponents(components)
return video.save_to(
path,
format=format,
codec=codec,
metadata=metadata
)
streams = container.streams
open_kwargs = get_open_write_kwargs(path, container_format, format)
with av.open(path, **open_kwargs) as output_container:
# Copy over the original metadata
for key, value in container.metadata.items():
if metadata is None or key not in metadata:
output_container.metadata[key] = value
# Add our new metadata
if metadata is not None:
for key, value in metadata.items():
if isinstance(value, str):
output_container.metadata[key] = value
else:
output_container.metadata[key] = json.dumps(value)
# Add streams to the new container
stream_map = {}
for stream in streams:
if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
stream_map[stream] = out_stream
# Write packets to the new container
for packet in container.demux():
if packet.stream in stream_map and packet.dts is not None:
packet.stream = stream_map[packet.stream]
output_container.mux(packet)
class VideoFromComponents(VideoInput):
"""
Class representing video input from tensors.
"""
def __init__(self, components: VideoComponents):
self.__components = components
def get_components(self) -> VideoComponents:
return VideoComponents(
images=self.__components.images,
audio=self.__components.audio,
frame_rate=self.__components.frame_rate
)
def save_to(
self,
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
raise ValueError("Only H264 codec is supported for now")
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
# Add metadata before writing any streams
if metadata is not None:
for key, value in metadata.items():
output.metadata[key] = json.dumps(value)
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
# Create a video stream
video_stream = output.add_stream('h264', rate=frame_rate)
video_stream.width = self.__components.images.shape[2]
video_stream.height = self.__components.images.shape[1]
video_stream.pix_fmt = 'yuv420p'
# Create an audio stream
audio_sample_rate = 1
audio_stream: Optional[av.AudioStream] = None
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
audio_stream.sample_rate = audio_sample_rate
audio_stream.format = 'fltp'
# Encode video
for i, frame in enumerate(self.__components.images):
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
packet = video_stream.encode(frame)
output.mux(packet)
# Flush video
packet = video_stream.encode(None)
output.mux(packet)
if audio_stream and self.__components.audio:
# Encode audio
samples_per_frame = int(audio_sample_rate / frame_rate)
num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
for i in range(num_frames):
start = i * samples_per_frame
end = start + samples_per_frame
# TODO(Feature) - Add support for stereo audio
chunk = (
self.__components.audio["waveform"][0, 0, start:end]
.unsqueeze(0)
.contiguous()
.numpy()
)
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
audio_frame.sample_rate = audio_sample_rate
audio_frame.pts = i * samples_per_frame
for packet in audio_stream.encode(audio_frame):
output.mux(packet)
# Flush audio
for packet in audio_stream.encode(None):
output.mux(packet)

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass
from enum import Enum
from fractions import Fraction
from typing import Optional
from comfy_api.input import ImageInput, AudioInput
from comfy_api.latest.input import ImageInput, AudioInput
class VideoCodec(str, Enum):
AUTO = "auto"
@ -49,3 +49,4 @@ class VideoComponents:
audio: Optional[AudioInput] = None
metadata: Optional[dict] = None

2
comfy_api/util.py Normal file
View File

@ -0,0 +1,2 @@
# This file only exists for backwards compatibility.
from comfy_api.latest.util import * # noqa: F403

View File

@ -0,0 +1,18 @@
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
from typing import Type, TYPE_CHECKING
from comfy_api.internal.async_to_sync import create_sync_class
# This version only exists to serve as a template for future version adapters.
# There is no reason anyone should ever use it.
class ComfyAPIAdapter_v0_0_1(ComfyAPIAdapter_v0_0_2):
VERSION = "0.0.1"
STABLE = True
ComfyAPI = ComfyAPIAdapter_v0_0_1
# Create a synchronous version of the API
if TYPE_CHECKING:
from comfy_api.v0_0_1.generated.ComfyAPISyncStub import ComfyAPISyncStub # type: ignore
ComfyAPISync: Type[ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_1)

View File

@ -0,0 +1,20 @@
from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
from PIL.Image import Image
from torch import Tensor
class ComfyAPISyncStub:
def __init__(self) -> None: ...
class ExecutionSync:
def __init__(self) -> None: ...
"""
Update the progress bar displayed in the ComfyUI interface.
This function allows custom nodes and API calls to report their progress
back to the user interface, providing visual feedback during long operations.
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
"""
def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[tuple[str, Image, Union[int, None]], Image, Tensor, None] = None) -> None: ...
execution: ExecutionSync

View File

@ -0,0 +1,15 @@
from comfy_api.latest import ComfyAPI_latest
from typing import Type, TYPE_CHECKING
from comfy_api.internal.async_to_sync import create_sync_class
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
VERSION = "0.0.2"
STABLE = False
ComfyAPI = ComfyAPIAdapter_v0_0_2
# Create a synchronous version of the API
if TYPE_CHECKING:
from comfy_api.v0_0_2.generated.ComfyAPISyncStub import ComfyAPISyncStub # type: ignore
ComfyAPISync: Type[ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_2)

View File

@ -0,0 +1,20 @@
from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
from PIL.Image import Image
from torch import Tensor
class ComfyAPISyncStub:
def __init__(self) -> None: ...
class ExecutionSync:
def __init__(self) -> None: ...
"""
Update the progress bar displayed in the ComfyUI interface.
This function allows custom nodes and API calls to report their progress
back to the user interface, providing visual feedback during long operations.
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
"""
def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[tuple[str, Image, Union[int, None]], Image, Tensor, None] = None) -> None: ...
execution: ExecutionSync

12
comfy_api/version_list.py Normal file
View File

@ -0,0 +1,12 @@
from comfy_api.latest import ComfyAPI_latest
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
from comfy_api.internal import ComfyAPIBase
from typing import List, Type
supported_versions: List[Type[ComfyAPIBase]] = [
ComfyAPI_latest,
ComfyAPIAdapter_v0_0_2,
ComfyAPIAdapter_v0_0_1,
]

View File

@ -1,4 +1,4 @@
from typing import TypedDict, Dict, Optional
from typing import TypedDict, Dict, Optional, Tuple
from typing_extensions import override
from PIL import Image
from enum import Enum
@ -10,6 +10,7 @@ if TYPE_CHECKING:
from protocol import BinaryEventTypes
from comfy_api import feature_flags
PreviewImageTuple = Tuple[str, Image.Image, Optional[int]]
class NodeState(Enum):
Pending = "pending"
@ -52,7 +53,7 @@ class ProgressHandler(ABC):
max_value: float,
state: NodeProgressState,
prompt_id: str,
image: Optional[Image.Image] = None,
image: PreviewImageTuple | None = None,
):
"""Called when a node's progress is updated"""
pass
@ -103,7 +104,7 @@ class CLIProgressHandler(ProgressHandler):
max_value: float,
state: NodeProgressState,
prompt_id: str,
image: Optional[Image.Image] = None,
image: PreviewImageTuple | None = None,
):
# Handle case where start_handler wasn't called
if node_id not in self.progress_bars:
@ -196,7 +197,7 @@ class WebUIProgressHandler(ProgressHandler):
max_value: float,
state: NodeProgressState,
prompt_id: str,
image: Optional[Image.Image] = None,
image: PreviewImageTuple | None = None,
):
# Send progress state of all nodes
if self.registry:
@ -231,7 +232,6 @@ class WebUIProgressHandler(ProgressHandler):
if self.registry:
self._send_progress_state(prompt_id, self.registry.nodes)
class ProgressRegistry:
"""
Registry that maintains node progress state and notifies registered handlers.
@ -285,7 +285,7 @@ class ProgressRegistry:
handler.start_handler(node_id, entry, self.prompt_id)
def update_progress(
self, node_id: str, value: float, max_value: float, image: Optional[Image.Image]
self, node_id: str, value: float, max_value: float, image: PreviewImageTuple | None = None
) -> None:
"""Update progress for a node"""
entry = self.ensure_entry(node_id)
@ -317,7 +317,7 @@ class ProgressRegistry:
handler.reset()
# Global registry instance
global_progress_registry: ProgressRegistry = None
global_progress_registry: ProgressRegistry | None = None
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
global global_progress_registry

View File

@ -8,9 +8,9 @@ import json
from typing import Optional, Literal
from fractions import Fraction
from comfy.comfy_types import IO, FileLocator, ComfyNodeABC
from comfy_api.input import ImageInput, AudioInput, VideoInput
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
from comfy_api.input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest.input import ImageInput, AudioInput, VideoInput
from comfy_api.latest.util import VideoContainer, VideoCodec, VideoComponents
from comfy_api.latest.input_impl import VideoFromFile, VideoFromComponents
from comfy.cli_args import args
class SaveWEBM:
@ -239,3 +239,4 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"GetVideoComponents": "Get Video Components",
"LoadVideo": "Load Video",
}

10
main.py
View File

@ -22,6 +22,12 @@ if __name__ == "__main__":
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
# Handle --generate-api-stubs early
if args.generate_api_stubs:
from comfy_api.generate_api_stubs import main as generate_stubs_main
generate_stubs_main()
sys.exit(0)
def apply_custom_paths():
# extra model paths
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
@ -304,10 +310,10 @@ def start_comfyui(asyncio_loop=None):
prompt_server = server.PromptServer(asyncio_loop)
hook_breaker_ac10a0.save_functions()
nodes.init_extra_nodes(
asyncio_loop.run_until_complete(nodes.init_extra_nodes(
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
init_api_nodes=not args.disable_api_nodes
)
))
hook_breaker_ac10a0.restore_functions()
cuda_malloc_warning()

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import torch
import os
import sys
import json
@ -26,6 +27,8 @@ import comfy.sd
import comfy.utils
import comfy.controlnet
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
from comfy_api.internal import register_versions, ComfyAPIWithVersion
from comfy_api.version_list import supported_versions
import comfy.clip_vision
@ -2101,7 +2104,7 @@ def get_module_name(module_path: str) -> str:
return base_path
def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
async def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
module_name = get_module_name(module_path)
if os.path.isfile(module_path):
sp = os.path.splitext(module_path)
@ -2165,7 +2168,7 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
return False
def init_external_custom_nodes():
async def init_external_custom_nodes():
"""
Initializes the external custom nodes.
@ -2191,7 +2194,7 @@ def init_external_custom_nodes():
logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
continue
time_before = time.perf_counter()
success = load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
node_import_times.append((time.perf_counter() - time_before, module_path, success))
if len(node_import_times) > 0:
@ -2204,7 +2207,7 @@ def init_external_custom_nodes():
logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
logging.info("")
def init_builtin_extra_nodes():
async def init_builtin_extra_nodes():
"""
Initializes the built-in extra nodes in ComfyUI.
@ -2288,13 +2291,13 @@ def init_builtin_extra_nodes():
import_failed = []
for node_file in extras_files:
if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
if not await load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
import_failed.append(node_file)
return import_failed
def init_builtin_api_nodes():
async def init_builtin_api_nodes():
api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes")
api_nodes_files = [
"nodes_ideogram.py",
@ -2315,26 +2318,35 @@ def init_builtin_api_nodes():
"nodes_gemini.py",
]
if not load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):
if not await load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):
return api_nodes_files
import_failed = []
for node_file in api_nodes_files:
if not load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"):
if not await load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"):
import_failed.append(node_file)
return import_failed
async def init_public_apis():
register_versions([
ComfyAPIWithVersion(
version=getattr(v, "VERSION"),
api_class=v
) for v in supported_versions
])
def init_extra_nodes(init_custom_nodes=True, init_api_nodes=True):
import_failed = init_builtin_extra_nodes()
async def init_extra_nodes(init_custom_nodes=True, init_api_nodes=True):
await init_public_apis()
import_failed = await init_builtin_extra_nodes()
import_failed_api = []
if init_api_nodes:
import_failed_api = init_builtin_api_nodes()
import_failed_api = await init_builtin_api_nodes()
if init_custom_nodes:
init_external_custom_nodes()
await init_external_custom_nodes()
else:
logging.info("Skipping loading of custom nodes")

View File

@ -21,4 +21,4 @@ lint.select = [
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
"F",
]
exclude = ["*.ipynb"]
exclude = ["*.ipynb", "**/generated/*.pyi"]

View File

@ -4,6 +4,7 @@ from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPING
from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS
from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS
from .async_test_nodes import ASYNC_TEST_NODE_CLASS_MAPPINGS, ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS
from .api_test_nodes import API_TEST_NODE_CLASS_MAPPINGS, API_TEST_NODE_DISPLAY_NAME_MAPPINGS
# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)
@ -15,6 +16,7 @@ NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(ASYNC_TEST_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(API_TEST_NODE_CLASS_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS)
@ -23,4 +25,4 @@ NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(API_TEST_NODE_DISPLAY_NAME_MAPPINGS)

View File

@ -0,0 +1,78 @@
import asyncio
import time
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
from comfy_api.v0_0_2 import ComfyAPI, ComfyAPISync
api = ComfyAPI()
api_sync = ComfyAPISync()
class TestAsyncProgressUpdate(ComfyNodeABC):
"""Test node with async VALIDATE_INPUTS."""
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"value": (IO.ANY, {}),
"sleep_seconds": (IO.FLOAT, {"default": 1.0}),
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "execute"
CATEGORY = "_for_testing/async"
async def execute(self, value, sleep_seconds):
start = time.time()
expiration = start + sleep_seconds
now = start
while now < expiration:
now = time.time()
await api.execution.set_progress(
value=(now - start) / sleep_seconds,
max_value=1.0,
)
await asyncio.sleep(0.01)
return (value,)
class TestSyncProgressUpdate(ComfyNodeABC):
"""Test node with async VALIDATE_INPUTS."""
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"value": (IO.ANY, {}),
"sleep_seconds": (IO.FLOAT, {"default": 1.0}),
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "execute"
CATEGORY = "_for_testing/async"
def execute(self, value, sleep_seconds):
start = time.time()
expiration = start + sleep_seconds
now = start
while now < expiration:
now = time.time()
api_sync.execution.set_progress(
value=(now - start) / sleep_seconds,
max_value=1.0,
)
time.sleep(0.01)
return (value,)
API_TEST_NODE_CLASS_MAPPINGS = {
"TestAsyncProgressUpdate": TestAsyncProgressUpdate,
"TestSyncProgressUpdate": TestSyncProgressUpdate,
}
API_TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestAsyncProgressUpdate": "Async Progress Update Test Node",
"TestSyncProgressUpdate": "Sync Progress Update Test Node",
}