Merge pull request #9037 from comfyanonymous/v3-definition-wip
V3 update - rebase on Core API PR, place v3 on latest
commit 631916dfb2
comfy/cli_args.py
@@ -155,6 +155,7 @@ parser.add_argument("--disable-metadata", action="store_true", help="Disable sav
 parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
 parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
 parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
+parser.add_argument("--generate-api-stubs", action="store_true", help="Generate .pyi stub files for API sync wrappers and exit.")
 parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
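How the new flag gets consumed is outside this hunk; a hedged sketch of the typical wiring in main.py (names here are illustrative, not part of the diff):

# Hypothetical consumer of --generate-api-stubs (not shown in this PR hunk):
if args.generate_api_stubs:
    from comfy_api.generate_api_stubs import main as generate_stubs
    generate_stubs()
    sys.exit(0)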
comfy_api/generate_api_stubs.py (new file, 86 lines)
@@ -0,0 +1,86 @@
#!/usr/bin/env python3
"""
Script to generate .pyi stub files for the synchronous API wrappers.
This allows generating stubs without running the full ComfyUI application.
"""

import os
import sys
import logging
import importlib

# Add ComfyUI to path so we can import modules
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from comfy_api.internal.async_to_sync import AsyncToSyncConverter
from comfy_api.version_list import supported_versions


def generate_stubs_for_module(module_name: str) -> None:
    """Generate stub files for a specific module that exports ComfyAPI and ComfyAPISync."""
    try:
        # Import the module
        module = importlib.import_module(module_name)

        # Check if module has ComfyAPISync (the sync wrapper)
        if hasattr(module, "ComfyAPISync"):
            # Module already has a sync class
            api_class = getattr(module, "ComfyAPI", None)
            sync_class = getattr(module, "ComfyAPISync")

            if api_class:
                # Generate the stub file
                AsyncToSyncConverter.generate_stub_file(api_class, sync_class)
                logging.info(f"Generated stub file for {module_name}")
            else:
                logging.warning(
                    f"Module {module_name} has ComfyAPISync but no ComfyAPI"
                )

        elif hasattr(module, "ComfyAPI"):
            # Module only has async API, need to create sync wrapper first
            from comfy_api.internal.async_to_sync import create_sync_class

            api_class = getattr(module, "ComfyAPI")
            sync_class = create_sync_class(api_class)

            # Generate the stub file
            AsyncToSyncConverter.generate_stub_file(api_class, sync_class)
            logging.info(f"Generated stub file for {module_name}")
        else:
            logging.warning(
                f"Module {module_name} does not export ComfyAPI or ComfyAPISync"
            )

    except Exception as e:
        logging.error(f"Failed to generate stub for {module_name}: {e}")
        import traceback

        traceback.print_exc()


def main():
    """Main function to generate all API stub files."""
    logging.basicConfig(level=logging.INFO)

    logging.info("Starting stub generation...")

    # Dynamically get module names from supported_versions
    api_modules = []
    for api_class in supported_versions:
        # Extract module name from the class
        module_name = api_class.__module__
        if module_name not in api_modules:
            api_modules.append(module_name)

    logging.info(f"Found {len(api_modules)} API modules: {api_modules}")

    # Generate stubs for each module
    for module_name in api_modules:
        generate_stubs_for_module(module_name)

    logging.info("Stub generation complete!")


if __name__ == "__main__":
    main()
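The stub generator can also be driven directly for a single module. A minimal sketch of the fallback branch above, assuming a hypothetical module that exports only an async ComfyAPI; note that generate_stub_file() returns silently for classes defined in __main__, so the class must live in an importable module:

from comfy_api.internal.async_to_sync import AsyncToSyncConverter, create_sync_class

class ComfyAPI:  # hypothetical async API surface, defined in an importable module
    async def ping(self) -> str:
        return "pong"

sync_class = create_sync_class(ComfyAPI)                       # build the sync wrapper
AsyncToSyncConverter.generate_stub_file(ComfyAPI, sync_class)  # writes generated/ComfyAPISyncStub.pyi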
comfy_api/input/__init__.py
@@ -1,8 +1,2 @@
-from .basic_types import ImageInput, AudioInput
-from .video_types import VideoInput
-
-__all__ = [
-    "ImageInput",
-    "AudioInput",
-    "VideoInput",
-]
+# This file only exists for backwards compatibility.
+from comfy_api.latest.input import *  # noqa: F403
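The net effect of the shim is that old and new import paths resolve to the same objects. A quick check, assuming comfy_api.latest.input exports the names the old module did:

import comfy_api.input as compat
import comfy_api.latest.input as canonical

assert compat.VideoInput is canonical.VideoInput  # star re-export binds the same class object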
comfy_api/input/basic_types.py
@@ -1,20 +1,2 @@
-import torch
-from typing import TypedDict
-
-ImageInput = torch.Tensor
-"""
-An image in format [B, H, W, C] where B is the batch size, C is the number of channels,
-"""
-
-class AudioInput(TypedDict):
-    """
-    TypedDict representing audio input.
-    """
-
-    waveform: torch.Tensor
-    """
-    Tensor in the format [B, C, T] where B is the batch size, C is the number of channels,
-    """
-
-    sample_rate: int
+# This file only exists for backwards compatibility.
+from comfy_api.latest.input.basic_types import *  # noqa: F403
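AudioInput stays a plain TypedDict after the move, so existing call sites keep constructing it as a dict literal; the shapes follow the docstrings in the removed lines above:

import torch
from comfy_api.input import AudioInput  # still importable through the shim

audio: AudioInput = {
    "waveform": torch.zeros(1, 2, 48000),  # [B, C, T]: one batch, stereo, one second at 48 kHz
    "sample_rate": 48000,
}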
comfy_api/input/video_types.py
@@ -1,72 +1,2 @@
-from __future__ import annotations
-from abc import ABC, abstractmethod
-from typing import Optional, Union
-import io
-from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
-
-class VideoInput(ABC):
-    """
-    Abstract base class for video input types.
-    """
-
-    @abstractmethod
-    def get_components(self) -> VideoComponents:
-        """
-        Abstract method to get the video components (images, audio, and frame rate).
-
-        Returns:
-            VideoComponents containing images, audio, and frame rate
-        """
-        pass
-
-    @abstractmethod
-    def save_to(
-        self,
-        path: str,
-        format: VideoContainer = VideoContainer.AUTO,
-        codec: VideoCodec = VideoCodec.AUTO,
-        metadata: Optional[dict] = None
-    ):
-        """
-        Abstract method to save the video input to a file.
-        """
-        pass
-
-    def get_stream_source(self) -> Union[str, io.BytesIO]:
-        """
-        Get a streamable source for the video. This allows processing without
-        loading the entire video into memory.
-
-        Returns:
-            Either a file path (str) or a BytesIO object that can be opened with av.
-
-        Default implementation creates a BytesIO buffer, but subclasses should
-        override this for better performance when possible.
-        """
-        buffer = io.BytesIO()
-        self.save_to(buffer)
-        buffer.seek(0)
-        return buffer
-
-    # Provide a default implementation, but subclasses can provide optimized versions
-    # if possible.
-    def get_dimensions(self) -> tuple[int, int]:
-        """
-        Returns the dimensions of the video input.
-
-        Returns:
-            Tuple of (width, height)
-        """
-        components = self.get_components()
-        return components.images.shape[2], components.images.shape[1]
-
-    def get_duration(self) -> float:
-        """
-        Returns the duration of the video in seconds.
-
-        Returns:
-            Duration in seconds
-        """
-        components = self.get_components()
-        frame_count = components.images.shape[0]
-        return float(frame_count / components.frame_rate)
+# This file only exists for backwards compatibility.
+from comfy_api.latest.input.video_types import *  # noqa: F403
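Since the relocated ABC derives get_dimensions() and get_duration() entirely from get_components(), a subclass only needs the two abstract methods. A toy sketch (the subclass is hypothetical):

from fractions import Fraction

import torch
from comfy_api.input import VideoInput
from comfy_api.util import VideoCodec, VideoComponents, VideoContainer

class BlackVideo(VideoInput):  # hypothetical subclass for illustration
    def get_components(self) -> VideoComponents:
        # 24 frames of 64x48 black video in [T, H, W, C] layout
        return VideoComponents(images=torch.zeros(24, 48, 64, 3), audio=None, frame_rate=Fraction(24))

    def save_to(self, path, format=VideoContainer.AUTO, codec=VideoCodec.AUTO, metadata=None):
        raise NotImplementedError

video = BlackVideo()
print(video.get_dimensions())  # (64, 48): width from images.shape[2], height from images.shape[1]
print(video.get_duration())    # 1.0: 24 frames / 24 fps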
comfy_api/input_impl/__init__.py
@@ -1,7 +1,2 @@
-from .video_types import VideoFromFile, VideoFromComponents
-
-__all__ = [
-    # Implementations
-    "VideoFromFile",
-    "VideoFromComponents",
-]
+# This file only exists for backwards compatibility.
+from comfy_api.latest.input_impl import *  # noqa: F403
comfy_api/input_impl/video_types.py
@@ -1,312 +1,2 @@
-from __future__ import annotations
-from av.container import InputContainer
-from av.subtitles.stream import SubtitleStream
-from fractions import Fraction
-from typing import Optional
-from comfy_api.input import AudioInput
-import av
-import io
-import json
-import numpy as np
-import torch
-from comfy_api.input import VideoInput
-from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
-
-
-def container_to_output_format(container_format: str | None) -> str | None:
-    """
-    A container's `format` may be a comma-separated list of formats.
-    E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
-    However, writing to a file/stream with `av.open` requires a single format,
-    or `None` to auto-detect.
-    """
-    if not container_format:
-        return None  # Auto-detect
-
-    if "," not in container_format:
-        return container_format
-
-    formats = container_format.split(",")
-    return formats[0]
-
-
-def get_open_write_kwargs(
-    dest: str | io.BytesIO, container_format: str, to_format: str | None
-) -> dict:
-    """Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
-    open_kwargs = {
-        "mode": "w",
-        # If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
-        "options": {"movflags": "use_metadata_tags"},
-    }
-
-    is_write_to_buffer = isinstance(dest, io.BytesIO)
-    if is_write_to_buffer:
-        # Set output format explicitly, since it cannot be inferred from file extension
-        if to_format == VideoContainer.AUTO:
-            to_format = container_format.lower()
-        elif isinstance(to_format, str):
-            to_format = to_format.lower()
-        open_kwargs["format"] = container_to_output_format(to_format)
-
-    return open_kwargs
-
-
-class VideoFromFile(VideoInput):
-    """
-    Class representing video input from a file.
-    """
-
-    def __init__(self, file: str | io.BytesIO):
-        """
-        Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
-        containing the file contents.
-        """
-        self.__file = file
-
-    def get_stream_source(self) -> str | io.BytesIO:
-        """
-        Return the underlying file source for efficient streaming.
-        This avoids unnecessary memory copies when the source is already a file path.
-        """
-        if isinstance(self.__file, io.BytesIO):
-            self.__file.seek(0)
-        return self.__file
-
-    def get_dimensions(self) -> tuple[int, int]:
-        """
-        Returns the dimensions of the video input.
-
-        Returns:
-            Tuple of (width, height)
-        """
-        if isinstance(self.__file, io.BytesIO):
-            self.__file.seek(0)  # Reset the BytesIO object to the beginning
-        with av.open(self.__file, mode='r') as container:
-            for stream in container.streams:
-                if stream.type == 'video':
-                    assert isinstance(stream, av.VideoStream)
-                    return stream.width, stream.height
-        raise ValueError(f"No video stream found in file '{self.__file}'")
-
-    def get_duration(self) -> float:
-        """
-        Returns the duration of the video in seconds.
-
-        Returns:
-            Duration in seconds
-        """
-        if isinstance(self.__file, io.BytesIO):
-            self.__file.seek(0)
-        with av.open(self.__file, mode="r") as container:
-            if container.duration is not None:
-                return float(container.duration / av.time_base)
-
-            # Fallback: calculate from frame count and frame rate
-            video_stream = next(
-                (s for s in container.streams if s.type == "video"), None
-            )
-            if video_stream and video_stream.frames and video_stream.average_rate:
-                return float(video_stream.frames / video_stream.average_rate)
-
-            # Last resort: decode frames to count them
-            if video_stream and video_stream.average_rate:
-                frame_count = 0
-                container.seek(0)
-                for packet in container.demux(video_stream):
-                    for _ in packet.decode():
-                        frame_count += 1
-                if frame_count > 0:
-                    return float(frame_count / video_stream.average_rate)
-
-        raise ValueError(f"Could not determine duration for file '{self.__file}'")
-
-    def get_components_internal(self, container: InputContainer) -> VideoComponents:
-        # Get video frames
-        frames = []
-        for frame in container.decode(video=0):
-            img = frame.to_ndarray(format='rgb24')  # shape: (H, W, 3)
-            img = torch.from_numpy(img) / 255.0  # shape: (H, W, 3)
-            frames.append(img)
-
-        images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
-
-        # Get frame rate
-        video_stream = next(s for s in container.streams if s.type == 'video')
-        frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
-
-        # Get audio if available
-        audio = None
-        try:
-            container.seek(0)  # Reset the container to the beginning
-            for stream in container.streams:
-                if stream.type != 'audio':
-                    continue
-                assert isinstance(stream, av.AudioStream)
-                audio_frames = []
-                for packet in container.demux(stream):
-                    for frame in packet.decode():
-                        assert isinstance(frame, av.AudioFrame)
-                        audio_frames.append(frame.to_ndarray())  # shape: (channels, samples)
-                if len(audio_frames) > 0:
-                    audio_data = np.concatenate(audio_frames, axis=1)  # shape: (channels, total_samples)
-                    audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)  # shape: (1, channels, total_samples)
-                    audio = AudioInput({
-                        "waveform": audio_tensor,
-                        "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
-                    })
-        except StopIteration:
-            pass  # No audio stream
-
-        metadata = container.metadata
-        return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
-
-    def get_components(self) -> VideoComponents:
-        if isinstance(self.__file, io.BytesIO):
-            self.__file.seek(0)  # Reset the BytesIO object to the beginning
-        with av.open(self.__file, mode='r') as container:
-            return self.get_components_internal(container)
-        raise ValueError(f"No video stream found in file '{self.__file}'")
-
-    def save_to(
-        self,
-        path: str | io.BytesIO,
-        format: VideoContainer = VideoContainer.AUTO,
-        codec: VideoCodec = VideoCodec.AUTO,
-        metadata: Optional[dict] = None
-    ):
-        if isinstance(self.__file, io.BytesIO):
-            self.__file.seek(0)  # Reset the BytesIO object to the beginning
-        with av.open(self.__file, mode='r') as container:
-            container_format = container.format.name
-            video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
-            reuse_streams = True
-            if format != VideoContainer.AUTO and format not in container_format.split(","):
-                reuse_streams = False
-            if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
-                reuse_streams = False
-
-            if not reuse_streams:
-                components = self.get_components_internal(container)
-                video = VideoFromComponents(components)
-                return video.save_to(
-                    path,
-                    format=format,
-                    codec=codec,
-                    metadata=metadata
-                )
-
-            streams = container.streams
-
-            open_kwargs = get_open_write_kwargs(path, container_format, format)
-            with av.open(path, **open_kwargs) as output_container:
-                # Copy over the original metadata
-                for key, value in container.metadata.items():
-                    if metadata is None or key not in metadata:
-                        output_container.metadata[key] = value
-
-                # Add our new metadata
-                if metadata is not None:
-                    for key, value in metadata.items():
-                        if isinstance(value, str):
-                            output_container.metadata[key] = value
-                        else:
-                            output_container.metadata[key] = json.dumps(value)
-
-                # Add streams to the new container
-                stream_map = {}
-                for stream in streams:
-                    if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
-                        out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
-                        stream_map[stream] = out_stream
-
-                # Write packets to the new container
-                for packet in container.demux():
-                    if packet.stream in stream_map and packet.dts is not None:
-                        packet.stream = stream_map[packet.stream]
-                        output_container.mux(packet)
-
-
-class VideoFromComponents(VideoInput):
-    """
-    Class representing video input from tensors.
-    """
-
-    def __init__(self, components: VideoComponents):
-        self.__components = components
-
-    def get_components(self) -> VideoComponents:
-        return VideoComponents(
-            images=self.__components.images,
-            audio=self.__components.audio,
-            frame_rate=self.__components.frame_rate
-        )
-
-    def save_to(
-        self,
-        path: str,
-        format: VideoContainer = VideoContainer.AUTO,
-        codec: VideoCodec = VideoCodec.AUTO,
-        metadata: Optional[dict] = None
-    ):
-        if format != VideoContainer.AUTO and format != VideoContainer.MP4:
-            raise ValueError("Only MP4 format is supported for now")
-        if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
-            raise ValueError("Only H264 codec is supported for now")
-        with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
-            # Add metadata before writing any streams
-            if metadata is not None:
-                for key, value in metadata.items():
-                    output.metadata[key] = json.dumps(value)
-
-            frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
-            # Create a video stream
-            video_stream = output.add_stream('h264', rate=frame_rate)
-            video_stream.width = self.__components.images.shape[2]
-            video_stream.height = self.__components.images.shape[1]
-            video_stream.pix_fmt = 'yuv420p'
-
-            # Create an audio stream
-            audio_sample_rate = 1
-            audio_stream: Optional[av.AudioStream] = None
-            if self.__components.audio:
-                audio_sample_rate = int(self.__components.audio['sample_rate'])
-                audio_stream = output.add_stream('aac', rate=audio_sample_rate)
-                audio_stream.sample_rate = audio_sample_rate
-                audio_stream.format = 'fltp'
-
-            # Encode video
-            for i, frame in enumerate(self.__components.images):
-                img = (frame * 255).clamp(0, 255).byte().cpu().numpy()  # shape: (H, W, 3)
-                frame = av.VideoFrame.from_ndarray(img, format='rgb24')
-                frame = frame.reformat(format='yuv420p')  # Convert to YUV420P as required by h264
-                packet = video_stream.encode(frame)
-                output.mux(packet)
-
-            # Flush video
-            packet = video_stream.encode(None)
-            output.mux(packet)
-
-            if audio_stream and self.__components.audio:
-                # Encode audio
-                samples_per_frame = int(audio_sample_rate / frame_rate)
-                num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
-                for i in range(num_frames):
-                    start = i * samples_per_frame
-                    end = start + samples_per_frame
-                    # TODO(Feature) - Add support for stereo audio
-                    chunk = (
-                        self.__components.audio["waveform"][0, 0, start:end]
-                        .unsqueeze(0)
-                        .contiguous()
-                        .numpy()
-                    )
-                    audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
-                    audio_frame.sample_rate = audio_sample_rate
-                    audio_frame.pts = i * samples_per_frame
-                    for packet in audio_stream.encode(audio_frame):
-                        output.mux(packet)
-
-                # Flush audio
-                for packet in audio_stream.encode(None):
-                    output.mux(packet)
+# This file only exists for backwards compatibility.
+from comfy_api.latest.input_impl.video_types import *  # noqa: F403
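Two helpers above are easy to sanity-check in isolation: ISO-BMFF containers report a comma-separated alias list, and writing to a BytesIO forces an explicit output format because there is no file extension to infer from:

import io

from comfy_api.input_impl.video_types import container_to_output_format, get_open_write_kwargs
from comfy_api.util import VideoContainer

assert container_to_output_format("mov,mp4,m4a,3gp,3g2,mj2") == "mov"  # first alias wins
assert container_to_output_format(None) is None                        # auto-detect

kwargs = get_open_write_kwargs(io.BytesIO(), "mov,mp4,m4a,3gp,3g2,mj2", VideoContainer.AUTO)
assert kwargs["mode"] == "w" and kwargs["format"] == "mov"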
comfy_api/internal/__init__.py
@@ -1,3 +1,11 @@
+# Internal infrastructure for ComfyAPI
+from .api_registry import (
+    ComfyAPIBase as ComfyAPIBase,
+    ComfyAPIWithVersion as ComfyAPIWithVersion,
+    register_versions as register_versions,
+    get_all_versions as get_all_versions,
+)
+
 import asyncio
 from dataclasses import asdict
 from typing import Callable, Optional
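The redundant name-as-name aliases are deliberate: they mark the names as explicit re-exports, so strict type-checker settings (e.g. mypy with implicit re-exports disabled) still accept:

from comfy_api.internal import ComfyAPIBase, register_versions  # recognized as re-exported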
comfy_api/internal/api_registry.py (new file, 39 lines)
@@ -0,0 +1,39 @@
from typing import Type, List, NamedTuple
from comfy_api.internal.singleton import ProxiedSingleton
from packaging import version as packaging_version


class ComfyAPIBase(ProxiedSingleton):
    def __init__(self):
        pass


class ComfyAPIWithVersion(NamedTuple):
    version: str
    api_class: Type[ComfyAPIBase]


def parse_version(version_str: str) -> packaging_version.Version:
    """
    Parses a version string into a packaging_version.Version object.
    Raises ValueError if the version string is invalid.
    """
    if version_str == "latest":
        return packaging_version.parse("9999999.9999999.9999999")
    return packaging_version.parse(version_str)


registered_versions: List[ComfyAPIWithVersion] = []


def register_versions(versions: List[ComfyAPIWithVersion]):
    versions.sort(key=lambda x: parse_version(x.version))
    global registered_versions
    registered_versions = versions


def get_all_versions() -> List[ComfyAPIWithVersion]:
    """
    Returns a list of all registered ComfyAPI versions.
    """
    return registered_versions
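parse_version() maps the sentinel "latest" onto an effectively infinite version, so register_versions() always sorts it after real releases. A small sketch with a hypothetical stand-in API class:

from comfy_api.internal.api_registry import ComfyAPIWithVersion, get_all_versions, register_versions

class DummyAPI:  # hypothetical stand-in for a real ComfyAPIBase subclass
    pass

register_versions([
    ComfyAPIWithVersion(version="latest", api_class=DummyAPI),
    ComfyAPIWithVersion(version="0.0.2", api_class=DummyAPI),
    ComfyAPIWithVersion(version="0.0.1", api_class=DummyAPI),
])
print([v.version for v in get_all_versions()])  # ['0.0.1', '0.0.2', 'latest']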
comfy_api/internal/async_to_sync.py (new file, 942 lines; the listing below is cut off by the page)
@@ -0,0 +1,942 @@
|
|||||||
|
import asyncio
|
||||||
|
import concurrent.futures
|
||||||
|
import contextvars
|
||||||
|
import functools
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import textwrap
|
||||||
|
import threading
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, Type, get_origin, get_args
|
||||||
|
|
||||||
|
|
||||||
|
class TypeTracker:
|
||||||
|
"""Tracks types discovered during stub generation for automatic import generation."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.discovered_types = {} # type_name -> (module, qualname)
|
||||||
|
self.builtin_types = {
|
||||||
|
"Any",
|
||||||
|
"Dict",
|
||||||
|
"List",
|
||||||
|
"Optional",
|
||||||
|
"Tuple",
|
||||||
|
"Union",
|
||||||
|
"Set",
|
||||||
|
"Sequence",
|
||||||
|
"cast",
|
||||||
|
"NamedTuple",
|
||||||
|
"str",
|
||||||
|
"int",
|
||||||
|
"float",
|
||||||
|
"bool",
|
||||||
|
"None",
|
||||||
|
"bytes",
|
||||||
|
"object",
|
||||||
|
"type",
|
||||||
|
"dict",
|
||||||
|
"list",
|
||||||
|
"tuple",
|
||||||
|
"set",
|
||||||
|
}
|
||||||
|
self.already_imported = (
|
||||||
|
set()
|
||||||
|
) # Track types already imported to avoid duplicates
|
||||||
|
|
||||||
|
def track_type(self, annotation):
|
||||||
|
"""Track a type annotation and record its module/import info."""
|
||||||
|
if annotation is None or annotation is type(None):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Skip builtins and typing module types we already import
|
||||||
|
type_name = getattr(annotation, "__name__", None)
|
||||||
|
if type_name and (
|
||||||
|
type_name in self.builtin_types or type_name in self.already_imported
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get module and qualname
|
||||||
|
module = getattr(annotation, "__module__", None)
|
||||||
|
qualname = getattr(annotation, "__qualname__", type_name or "")
|
||||||
|
|
||||||
|
# Skip types from typing module (they're already imported)
|
||||||
|
if module == "typing":
|
||||||
|
return
|
||||||
|
|
||||||
|
# Skip UnionType and GenericAlias from types module as they're handled specially
|
||||||
|
if module == "types" and type_name in ("UnionType", "GenericAlias"):
|
||||||
|
return
|
||||||
|
|
||||||
|
if module and module not in ["builtins", "__main__"]:
|
||||||
|
# Store the type info
|
||||||
|
if type_name:
|
||||||
|
self.discovered_types[type_name] = (module, qualname)
|
||||||
|
|
||||||
|
def get_imports(self, main_module_name: str) -> list[str]:
|
||||||
|
"""Generate import statements for all discovered types."""
|
||||||
|
imports = []
|
||||||
|
imports_by_module = {}
|
||||||
|
|
||||||
|
for type_name, (module, qualname) in sorted(self.discovered_types.items()):
|
||||||
|
# Skip types from the main module (they're already imported)
|
||||||
|
if main_module_name and module == main_module_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if module not in imports_by_module:
|
||||||
|
imports_by_module[module] = []
|
||||||
|
if type_name not in imports_by_module[module]: # Avoid duplicates
|
||||||
|
imports_by_module[module].append(type_name)
|
||||||
|
|
||||||
|
# Generate import statements
|
||||||
|
for module, types in sorted(imports_by_module.items()):
|
||||||
|
if len(types) == 1:
|
||||||
|
imports.append(f"from {module} import {types[0]}")
|
||||||
|
else:
|
||||||
|
imports.append(f"from {module} import {', '.join(sorted(set(types)))}")
|
||||||
|
|
||||||
|
return imports
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncToSyncConverter:
|
||||||
|
"""
|
||||||
|
Provides utilities to convert async classes to sync classes with proper type hints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_thread_pool: Optional[concurrent.futures.ThreadPoolExecutor] = None
|
||||||
|
_thread_pool_lock = threading.Lock()
|
||||||
|
_thread_pool_initialized = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_thread_pool(cls, max_workers=None) -> concurrent.futures.ThreadPoolExecutor:
|
||||||
|
"""Get or create the shared thread pool with proper thread-safe initialization."""
|
||||||
|
# Fast path - check if already initialized without acquiring lock
|
||||||
|
if cls._thread_pool_initialized:
|
||||||
|
assert cls._thread_pool is not None, "Thread pool should be initialized"
|
||||||
|
return cls._thread_pool
|
||||||
|
|
||||||
|
# Slow path - acquire lock and create pool if needed
|
||||||
|
with cls._thread_pool_lock:
|
||||||
|
if not cls._thread_pool_initialized:
|
||||||
|
cls._thread_pool = concurrent.futures.ThreadPoolExecutor(
|
||||||
|
max_workers=max_workers, thread_name_prefix="async_to_sync_"
|
||||||
|
)
|
||||||
|
cls._thread_pool_initialized = True
|
||||||
|
|
||||||
|
# This should never be None at this point, but add assertion for type checker
|
||||||
|
assert cls._thread_pool is not None
|
||||||
|
return cls._thread_pool
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def run_async_in_thread(cls, coro_func, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Run an async function in a separate thread from the thread pool.
|
||||||
|
Blocks until the async function completes.
|
||||||
|
Properly propagates contextvars between threads and manages event loops.
|
||||||
|
"""
|
||||||
|
# Capture current context - this includes all context variables
|
||||||
|
context = contextvars.copy_context()
|
||||||
|
|
||||||
|
# Store the result and any exception that occurs
|
||||||
|
result_container: dict = {"result": None, "exception": None}
|
||||||
|
|
||||||
|
# Function that runs in the thread pool
|
||||||
|
def run_in_thread():
|
||||||
|
# Create new event loop for this thread
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create the coroutine within the context
|
||||||
|
async def run_with_context():
|
||||||
|
# The coroutine function might access context variables
|
||||||
|
return await coro_func(*args, **kwargs)
|
||||||
|
|
||||||
|
# Run the coroutine with the captured context
|
||||||
|
# This ensures all context variables are available in the async function
|
||||||
|
result = context.run(loop.run_until_complete, run_with_context())
|
||||||
|
result_container["result"] = result
|
||||||
|
except Exception as e:
|
||||||
|
# Store the exception to re-raise in the calling thread
|
||||||
|
result_container["exception"] = e
|
||||||
|
finally:
|
||||||
|
# Ensure event loop is properly closed to prevent warnings
|
||||||
|
try:
|
||||||
|
# Cancel any remaining tasks
|
||||||
|
pending = asyncio.all_tasks(loop)
|
||||||
|
for task in pending:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
# Run the loop briefly to handle cancellations
|
||||||
|
if pending:
|
||||||
|
loop.run_until_complete(
|
||||||
|
asyncio.gather(*pending, return_exceptions=True)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Ignore errors during cleanup
|
||||||
|
|
||||||
|
# Close the event loop
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
# Clear the event loop from the thread
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
|
||||||
|
# Submit to thread pool and wait for result
|
||||||
|
thread_pool = cls.get_thread_pool()
|
||||||
|
future = thread_pool.submit(run_in_thread)
|
||||||
|
future.result() # Wait for completion
|
||||||
|
|
||||||
|
# Re-raise any exception that occurred in the thread
|
||||||
|
if result_container["exception"] is not None:
|
||||||
|
raise result_container["exception"]
|
||||||
|
|
||||||
|
return result_container["result"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
|
||||||
|
"""
|
||||||
|
Creates a new class with synchronous versions of all async methods.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
async_class: The async class to convert
|
||||||
|
thread_pool_size: Size of thread pool to use
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new class with sync versions of all async methods
|
||||||
|
"""
|
||||||
|
sync_class_name = "ComfyAPISyncStub"
|
||||||
|
cls.get_thread_pool(thread_pool_size)
|
||||||
|
|
||||||
|
# Create a proper class with docstrings and proper base classes
|
||||||
|
sync_class_dict = {
|
||||||
|
"__doc__": async_class.__doc__,
|
||||||
|
"__module__": async_class.__module__,
|
||||||
|
"__qualname__": sync_class_name,
|
||||||
|
"__orig_class__": async_class, # Store original class for typing references
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create __init__ method
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self._async_instance = async_class(*args, **kwargs)
|
||||||
|
|
||||||
|
# Handle annotated class attributes (like execution: Execution)
|
||||||
|
# Get all annotations from the class hierarchy
|
||||||
|
all_annotations = {}
|
||||||
|
for base_class in reversed(inspect.getmro(async_class)):
|
||||||
|
if hasattr(base_class, "__annotations__"):
|
||||||
|
all_annotations.update(base_class.__annotations__)
|
||||||
|
|
||||||
|
# For each annotated attribute, check if it needs to be created or wrapped
|
||||||
|
for attr_name, attr_type in all_annotations.items():
|
||||||
|
if hasattr(self._async_instance, attr_name):
|
||||||
|
# Attribute exists on the instance
|
||||||
|
attr = getattr(self._async_instance, attr_name)
|
||||||
|
# Check if this attribute needs a sync wrapper
|
||||||
|
if hasattr(attr, "__class__"):
|
||||||
|
from comfy_api.internal.singleton import ProxiedSingleton
|
||||||
|
|
||||||
|
if isinstance(attr, ProxiedSingleton):
|
||||||
|
# Create a sync version of this attribute
|
||||||
|
try:
|
||||||
|
sync_attr_class = cls.create_sync_class(attr.__class__)
|
||||||
|
# Create instance of the sync wrapper with the async instance
|
||||||
|
sync_attr = object.__new__(sync_attr_class) # type: ignore
|
||||||
|
sync_attr._async_instance = attr
|
||||||
|
setattr(self, attr_name, sync_attr)
|
||||||
|
except Exception:
|
||||||
|
# If we can't create a sync version, keep the original
|
||||||
|
setattr(self, attr_name, attr)
|
||||||
|
else:
|
||||||
|
# Not async, just copy the reference
|
||||||
|
setattr(self, attr_name, attr)
|
||||||
|
else:
|
||||||
|
# Attribute doesn't exist, but is annotated - create it
|
||||||
|
# This handles cases like execution: Execution
|
||||||
|
if isinstance(attr_type, type):
|
||||||
|
# Check if the type is defined as an inner class
|
||||||
|
if hasattr(async_class, attr_type.__name__):
|
||||||
|
inner_class = getattr(async_class, attr_type.__name__)
|
||||||
|
from comfy_api.internal.singleton import ProxiedSingleton
|
||||||
|
|
||||||
|
# Create an instance of the inner class
|
||||||
|
try:
|
||||||
|
# For ProxiedSingleton classes, get or create the singleton instance
|
||||||
|
if issubclass(inner_class, ProxiedSingleton):
|
||||||
|
async_instance = inner_class.get_instance()
|
||||||
|
else:
|
||||||
|
async_instance = inner_class()
|
||||||
|
|
||||||
|
# Create sync wrapper
|
||||||
|
sync_attr_class = cls.create_sync_class(inner_class)
|
||||||
|
sync_attr = object.__new__(sync_attr_class) # type: ignore
|
||||||
|
sync_attr._async_instance = async_instance
|
||||||
|
setattr(self, attr_name, sync_attr)
|
||||||
|
# Also set on the async instance for consistency
|
||||||
|
setattr(self._async_instance, attr_name, async_instance)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(
|
||||||
|
f"Failed to create instance for {attr_name}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle other instance attributes that might not be annotated
|
||||||
|
for name, attr in inspect.getmembers(self._async_instance):
|
||||||
|
if name.startswith("_") or hasattr(self, name):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If attribute is an instance of a class, and that class is defined in the original class
|
||||||
|
# we need to check if it needs a sync wrapper
|
||||||
|
if isinstance(attr, object) and not isinstance(
|
||||||
|
attr, (str, int, float, bool, list, dict, tuple)
|
||||||
|
):
|
||||||
|
from comfy_api.internal.singleton import ProxiedSingleton
|
||||||
|
|
||||||
|
if isinstance(attr, ProxiedSingleton):
|
||||||
|
# Create a sync version of this nested class
|
||||||
|
try:
|
||||||
|
sync_attr_class = cls.create_sync_class(attr.__class__)
|
||||||
|
# Create instance of the sync wrapper with the async instance
|
||||||
|
sync_attr = object.__new__(sync_attr_class) # type: ignore
|
||||||
|
sync_attr._async_instance = attr
|
||||||
|
setattr(self, name, sync_attr)
|
||||||
|
except Exception:
|
||||||
|
# If we can't create a sync version, keep the original
|
||||||
|
setattr(self, name, attr)
|
||||||
|
|
||||||
|
sync_class_dict["__init__"] = __init__
|
||||||
|
|
||||||
|
# Process methods from the async class
|
||||||
|
for name, method in inspect.getmembers(
|
||||||
|
async_class, predicate=inspect.isfunction
|
||||||
|
):
|
||||||
|
if name.startswith("_"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Extract the actual return type from a coroutine
|
||||||
|
if inspect.iscoroutinefunction(method):
|
||||||
|
# Create sync version of async method with proper signature
|
||||||
|
@functools.wraps(method)
|
||||||
|
def sync_method(self, *args, _method_name=name, **kwargs):
|
||||||
|
async_method = getattr(self._async_instance, _method_name)
|
||||||
|
return AsyncToSyncConverter.run_async_in_thread(
|
||||||
|
async_method, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add to the class dict
|
||||||
|
sync_class_dict[name] = sync_method
|
||||||
|
else:
|
||||||
|
# For regular methods, create a proxy method
|
||||||
|
@functools.wraps(method)
|
||||||
|
def proxy_method(self, *args, _method_name=name, **kwargs):
|
||||||
|
method = getattr(self._async_instance, _method_name)
|
||||||
|
return method(*args, **kwargs)
|
||||||
|
|
||||||
|
# Add to the class dict
|
||||||
|
sync_class_dict[name] = proxy_method
|
||||||
|
|
||||||
|
# Handle property access
|
||||||
|
for name, prop in inspect.getmembers(
|
||||||
|
async_class, lambda x: isinstance(x, property)
|
||||||
|
):
|
||||||
|
|
||||||
|
def make_property(name, prop_obj):
|
||||||
|
def getter(self):
|
||||||
|
value = getattr(self._async_instance, name)
|
||||||
|
if inspect.iscoroutinefunction(value):
|
||||||
|
|
||||||
|
def sync_fn(*args, **kwargs):
|
||||||
|
return AsyncToSyncConverter.run_async_in_thread(
|
||||||
|
value, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return sync_fn
|
||||||
|
return value
|
||||||
|
|
||||||
|
def setter(self, value):
|
||||||
|
setattr(self._async_instance, name, value)
|
||||||
|
|
||||||
|
return property(getter, setter if prop_obj.fset else None)
|
||||||
|
|
||||||
|
sync_class_dict[name] = make_property(name, prop)
|
||||||
|
|
||||||
|
# Create the class
|
||||||
|
sync_class = type(sync_class_name, (object,), sync_class_dict)
|
||||||
|
|
||||||
|
return sync_class
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _format_type_annotation(
|
||||||
|
cls, annotation, type_tracker: Optional[TypeTracker] = None
|
||||||
|
) -> str:
|
||||||
|
"""Convert a type annotation to its string representation for stub files."""
|
||||||
|
if (
|
||||||
|
annotation is inspect.Parameter.empty
|
||||||
|
or annotation is inspect.Signature.empty
|
||||||
|
):
|
||||||
|
return "Any"
|
||||||
|
|
||||||
|
# Handle None type
|
||||||
|
if annotation is type(None):
|
||||||
|
return "None"
|
||||||
|
|
||||||
|
# Track the type if we have a tracker
|
||||||
|
if type_tracker:
|
||||||
|
type_tracker.track_type(annotation)
|
||||||
|
|
||||||
|
# Try using typing.get_origin/get_args for Python 3.8+
|
||||||
|
try:
|
||||||
|
origin = get_origin(annotation)
|
||||||
|
args = get_args(annotation)
|
||||||
|
|
||||||
|
if origin is not None:
|
||||||
|
# Track the origin type
|
||||||
|
if type_tracker:
|
||||||
|
type_tracker.track_type(origin)
|
||||||
|
|
||||||
|
# Get the origin name
|
||||||
|
origin_name = getattr(origin, "__name__", str(origin))
|
||||||
|
if "." in origin_name:
|
||||||
|
origin_name = origin_name.split(".")[-1]
|
||||||
|
|
||||||
|
# Special handling for types.UnionType (Python 3.10+ pipe operator)
|
||||||
|
if origin_name == "UnionType":
|
||||||
|
origin_name = "Union"
|
||||||
|
|
||||||
|
# Format arguments recursively
|
||||||
|
if args:
|
||||||
|
formatted_args = [
|
||||||
|
cls._format_type_annotation(arg, type_tracker) for arg in args
|
||||||
|
]
|
||||||
|
return f"{origin_name}[{', '.join(formatted_args)}]"
|
||||||
|
else:
|
||||||
|
return origin_name
|
||||||
|
except (AttributeError, TypeError):
|
||||||
|
# Fallback for older Python versions or non-generic types
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Handle generic types the old way for compatibility
|
||||||
|
if hasattr(annotation, "__origin__") and hasattr(annotation, "__args__"):
|
||||||
|
origin = annotation.__origin__
|
||||||
|
origin_name = (
|
||||||
|
origin.__name__
|
||||||
|
if hasattr(origin, "__name__")
|
||||||
|
else str(origin).split("'")[1]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Format each type argument
|
||||||
|
args = []
|
||||||
|
for arg in annotation.__args__:
|
||||||
|
args.append(cls._format_type_annotation(arg, type_tracker))
|
||||||
|
|
||||||
|
return f"{origin_name}[{', '.join(args)}]"
|
||||||
|
|
||||||
|
# Handle regular types with __name__
|
||||||
|
if hasattr(annotation, "__name__"):
|
||||||
|
return annotation.__name__
|
||||||
|
|
||||||
|
# Handle special module types (like types from typing module)
|
||||||
|
if hasattr(annotation, "__module__") and hasattr(annotation, "__qualname__"):
|
||||||
|
# For types like typing.Literal, typing.TypedDict, etc.
|
||||||
|
return annotation.__qualname__
|
||||||
|
|
||||||
|
# Last resort: string conversion with cleanup
|
||||||
|
type_str = str(annotation)
|
||||||
|
|
||||||
|
# Clean up common patterns more robustly
|
||||||
|
if type_str.startswith("<class '") and type_str.endswith("'>"):
|
||||||
|
type_str = type_str[8:-2] # Remove "<class '" and "'>"
|
||||||
|
|
||||||
|
# Remove module prefixes for common modules
|
||||||
|
for prefix in ["typing.", "builtins.", "types."]:
|
||||||
|
if type_str.startswith(prefix):
|
||||||
|
type_str = type_str[len(prefix) :]
|
||||||
|
|
||||||
|
# Handle special cases
|
||||||
|
if type_str in ("_empty", "inspect._empty"):
|
||||||
|
return "None"
|
||||||
|
|
||||||
|
# Fix NoneType (this should rarely be needed now)
|
||||||
|
if type_str == "NoneType":
|
||||||
|
return "None"
|
||||||
|
|
||||||
|
return type_str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _extract_coroutine_return_type(cls, annotation):
|
||||||
|
"""Extract the actual return type from a Coroutine annotation."""
|
||||||
|
if hasattr(annotation, "__args__") and len(annotation.__args__) > 2:
|
||||||
|
# Coroutine[Any, Any, ReturnType] -> extract ReturnType
|
||||||
|
return annotation.__args__[2]
|
||||||
|
return annotation
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _format_parameter_default(cls, default_value) -> str:
|
||||||
|
"""Format a parameter's default value for stub files."""
|
||||||
|
if default_value is inspect.Parameter.empty:
|
||||||
|
return ""
|
||||||
|
elif default_value is None:
|
||||||
|
return " = None"
|
||||||
|
elif isinstance(default_value, bool):
|
||||||
|
return f" = {default_value}"
|
||||||
|
elif default_value == {}:
|
||||||
|
return " = {}"
|
||||||
|
elif default_value == []:
|
||||||
|
return " = []"
|
||||||
|
else:
|
||||||
|
return f" = {default_value}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _format_method_parameters(
|
||||||
|
cls,
|
||||||
|
sig: inspect.Signature,
|
||||||
|
skip_self: bool = True,
|
||||||
|
type_tracker: Optional[TypeTracker] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Format method parameters for stub files."""
|
||||||
|
params = []
|
||||||
|
|
||||||
|
for i, (param_name, param) in enumerate(sig.parameters.items()):
|
||||||
|
if i == 0 and param_name == "self" and skip_self:
|
||||||
|
params.append("self")
|
||||||
|
else:
|
||||||
|
# Get type annotation
|
||||||
|
type_str = cls._format_type_annotation(param.annotation, type_tracker)
|
||||||
|
|
||||||
|
# Get default value
|
||||||
|
default_str = cls._format_parameter_default(param.default)
|
||||||
|
|
||||||
|
# Combine parameter parts
|
||||||
|
if param.annotation is inspect.Parameter.empty:
|
||||||
|
params.append(f"{param_name}: Any{default_str}")
|
||||||
|
else:
|
||||||
|
params.append(f"{param_name}: {type_str}{default_str}")
|
||||||
|
|
||||||
|
return ", ".join(params)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _generate_method_signature(
|
||||||
|
cls,
|
||||||
|
method_name: str,
|
||||||
|
method,
|
||||||
|
is_async: bool = False,
|
||||||
|
type_tracker: Optional[TypeTracker] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Generate a complete method signature for stub files."""
|
||||||
|
sig = inspect.signature(method)
|
||||||
|
|
||||||
|
# For async methods, extract the actual return type
|
||||||
|
return_annotation = sig.return_annotation
|
||||||
|
if is_async and inspect.iscoroutinefunction(method):
|
||||||
|
return_annotation = cls._extract_coroutine_return_type(return_annotation)
|
||||||
|
|
||||||
|
# Format parameters
|
||||||
|
params_str = cls._format_method_parameters(sig, type_tracker=type_tracker)
|
||||||
|
|
||||||
|
# Format return type
|
||||||
|
return_type = cls._format_type_annotation(return_annotation, type_tracker)
|
||||||
|
if return_annotation is inspect.Signature.empty:
|
||||||
|
return_type = "None"
|
||||||
|
|
||||||
|
return f"def {method_name}({params_str}) -> {return_type}: ..."
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _generate_imports(
|
||||||
|
cls, async_class: Type, type_tracker: TypeTracker
|
||||||
|
) -> list[str]:
|
||||||
|
"""Generate import statements for the stub file."""
|
||||||
|
imports = []
|
||||||
|
|
||||||
|
# Add standard typing imports
|
||||||
|
imports.append(
|
||||||
|
"from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add imports from the original module
|
||||||
|
if async_class.__module__ != "builtins":
|
||||||
|
module = inspect.getmodule(async_class)
|
||||||
|
additional_types = []
|
||||||
|
|
||||||
|
if module:
|
||||||
|
for name, obj in sorted(inspect.getmembers(module)):
|
||||||
|
if isinstance(obj, type):
|
||||||
|
# Check for NamedTuple
|
||||||
|
if issubclass(obj, tuple) and hasattr(obj, "_fields"):
|
||||||
|
additional_types.append(name)
|
||||||
|
# Mark as already imported
|
||||||
|
type_tracker.already_imported.add(name)
|
||||||
|
# Check for Enum
|
||||||
|
elif issubclass(obj, Enum) and name != "Enum":
|
||||||
|
additional_types.append(name)
|
||||||
|
# Mark as already imported
|
||||||
|
type_tracker.already_imported.add(name)
|
||||||
|
|
||||||
|
if additional_types:
|
||||||
|
type_imports = ", ".join([async_class.__name__] + additional_types)
|
||||||
|
imports.append(f"from {async_class.__module__} import {type_imports}")
|
||||||
|
else:
|
||||||
|
imports.append(
|
||||||
|
f"from {async_class.__module__} import {async_class.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add imports for all discovered types
|
||||||
|
# Pass the main module name to avoid duplicate imports
|
||||||
|
imports.extend(
|
||||||
|
type_tracker.get_imports(main_module_name=async_class.__module__)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add base module import if needed
|
||||||
|
if hasattr(inspect.getmodule(async_class), "__name__"):
|
||||||
|
module_name = inspect.getmodule(async_class).__name__
|
||||||
|
if "." in module_name:
|
||||||
|
base_module = module_name.split(".")[0]
|
||||||
|
# Only add if not already importing from it
|
||||||
|
if not any(imp.startswith(f"from {base_module}") for imp in imports):
|
||||||
|
imports.append(f"import {base_module}")
|
||||||
|
|
||||||
|
return imports
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
|
||||||
|
"""Extract class attributes that are classes themselves."""
|
||||||
|
class_attributes = []
|
||||||
|
|
||||||
|
# Look for class attributes that are classes
|
||||||
|
for name, attr in sorted(inspect.getmembers(async_class)):
|
||||||
|
if isinstance(attr, type) and not name.startswith("_"):
|
||||||
|
class_attributes.append((name, attr))
|
||||||
|
elif (
|
||||||
|
hasattr(async_class, "__annotations__")
|
||||||
|
and name in async_class.__annotations__
|
||||||
|
):
|
||||||
|
annotation = async_class.__annotations__[name]
|
||||||
|
if isinstance(annotation, type):
|
||||||
|
class_attributes.append((name, annotation))
|
||||||
|
|
||||||
|
return class_attributes
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _generate_inner_class_stub(
|
||||||
|
cls,
|
||||||
|
name: str,
|
||||||
|
attr: Type,
|
||||||
|
indent: str = " ",
|
||||||
|
type_tracker: Optional[TypeTracker] = None,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Generate stub for an inner class."""
|
||||||
|
stub_lines = []
|
||||||
|
stub_lines.append(f"{indent}class {name}Sync:")
|
||||||
|
|
||||||
|
# Add docstring if available
|
||||||
|
if hasattr(attr, "__doc__") and attr.__doc__:
|
||||||
|
stub_lines.extend(
|
||||||
|
cls._format_docstring_for_stub(attr.__doc__, f"{indent} ")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add __init__ if it exists
|
||||||
|
if hasattr(attr, "__init__"):
|
||||||
|
try:
|
||||||
|
init_method = getattr(attr, "__init__")
|
||||||
|
init_sig = inspect.signature(init_method)
|
||||||
|
# Format parameters
|
||||||
|
params_str = cls._format_method_parameters(
|
||||||
|
init_sig, type_tracker=type_tracker
|
||||||
|
)
|
||||||
|
# Add __init__ docstring if available (before the method)
|
||||||
|
if hasattr(init_method, "__doc__") and init_method.__doc__:
|
||||||
|
stub_lines.extend(
|
||||||
|
cls._format_docstring_for_stub(
|
||||||
|
init_method.__doc__, f"{indent} "
|
||||||
|
)
|
||||||
|
)
|
||||||
|
stub_lines.append(
|
||||||
|
f"{indent} def __init__({params_str}) -> None: ..."
|
||||||
|
)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
stub_lines.append(
|
||||||
|
f"{indent} def __init__(self, *args, **kwargs) -> None: ..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add methods to the inner class
|
||||||
|
has_methods = False
|
||||||
|
for method_name, method in sorted(
|
||||||
|
inspect.getmembers(attr, predicate=inspect.isfunction)
|
||||||
|
):
|
||||||
|
if method_name.startswith("_"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
has_methods = True
|
||||||
|
try:
|
||||||
|
# Add method docstring if available (before the method signature)
|
||||||
|
if method.__doc__:
|
||||||
|
stub_lines.extend(
|
||||||
|
cls._format_docstring_for_stub(method.__doc__, f"{indent} ")
|
||||||
|
)
|
||||||
|
|
||||||
|
method_sig = cls._generate_method_signature(
|
||||||
|
method_name, method, is_async=True, type_tracker=type_tracker
|
||||||
|
)
|
||||||
|
stub_lines.append(f"{indent} {method_sig}")
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
stub_lines.append(
|
||||||
|
f"{indent} def {method_name}(self, *args, **kwargs): ..."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not has_methods:
|
||||||
|
stub_lines.append(f"{indent} pass")
|
||||||
|
|
||||||
|
return stub_lines
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _format_docstring_for_stub(
|
||||||
|
cls, docstring: str, indent: str = " "
|
||||||
|
) -> list[str]:
|
||||||
|
"""Format a docstring for inclusion in a stub file with proper indentation."""
|
||||||
|
if not docstring:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# First, dedent the docstring to remove any existing indentation
|
||||||
|
dedented = textwrap.dedent(docstring).strip()
|
||||||
|
|
||||||
|
# Split into lines
|
||||||
|
lines = dedented.split("\n")
|
||||||
|
|
||||||
|
# Build the properly indented docstring
|
||||||
|
result = []
|
||||||
|
result.append(f'{indent}"""')
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
if line.strip(): # Non-empty line
|
||||||
|
result.append(f"{indent}{line}")
|
||||||
|
else: # Empty line
|
||||||
|
result.append("")
|
||||||
|
|
||||||
|
result.append(f'{indent}"""')
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _post_process_stub_content(cls, stub_content: list[str]) -> list[str]:
|
||||||
|
"""Post-process stub content to fix any remaining issues."""
|
||||||
|
processed = []
|
||||||
|
|
||||||
|
for line in stub_content:
|
||||||
|
# Skip processing imports
|
||||||
|
if line.startswith(("from ", "import ")):
|
||||||
|
processed.append(line)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Fix method signatures missing return types
|
||||||
|
if (
|
||||||
|
line.strip().startswith("def ")
|
||||||
|
and line.strip().endswith(": ...")
|
||||||
|
and ") -> " not in line
|
||||||
|
):
|
||||||
|
# Add -> None for methods without return annotation
|
||||||
|
line = line.replace(": ...", " -> None: ...")
|
||||||
|
|
||||||
|
processed.append(line)
|
||||||
|
|
||||||
|
return processed
|
||||||
|
|
||||||
|
    @classmethod
    def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
        """
        Generate a .pyi stub file for the sync class to help IDEs with type checking.
        """
        try:
            # Only generate stub if we can determine module path
            if async_class.__module__ == "__main__":
                return

            module = inspect.getmodule(async_class)
            if not module:
                return

            module_path = module.__file__
            if not module_path:
                return

            # Create stub file path in a 'generated' subdirectory
            module_dir = os.path.dirname(module_path)
            stub_dir = os.path.join(module_dir, "generated")

            # Ensure the generated directory exists
            os.makedirs(stub_dir, exist_ok=True)

            module_name = os.path.basename(module_path)
            if module_name.endswith(".py"):
                module_name = module_name[:-3]

            sync_stub_path = os.path.join(stub_dir, f"{sync_class.__name__}.pyi")

            # Create a type tracker for this stub generation
            type_tracker = TypeTracker()

            stub_content = []

            # We'll generate imports after processing all methods to capture all types
            # Leave a placeholder for imports
            imports_placeholder_index = len(stub_content)
            stub_content.append("")  # Will be replaced with imports later

            # Class definition
            stub_content.append(f"class {sync_class.__name__}:")

            # Docstring
            if async_class.__doc__:
                stub_content.extend(
                    cls._format_docstring_for_stub(async_class.__doc__, "    ")
                )

            # Generate __init__
            try:
                init_method = async_class.__init__
                init_signature = inspect.signature(init_method)
                # Format parameters
                params_str = cls._format_method_parameters(
                    init_signature, type_tracker=type_tracker
                )
                # Add __init__ docstring if available (before the method)
                if hasattr(init_method, "__doc__") and init_method.__doc__:
                    stub_content.extend(
                        cls._format_docstring_for_stub(init_method.__doc__, "    ")
                    )
                stub_content.append(f"    def __init__({params_str}) -> None: ...")
            except (ValueError, TypeError):
                stub_content.append(
                    "    def __init__(self, *args, **kwargs) -> None: ..."
                )

            stub_content.append("")  # Add newline after __init__

            # Get class attributes
            class_attributes = cls._get_class_attributes(async_class)

            # Generate inner classes
            for name, attr in class_attributes:
                inner_class_stub = cls._generate_inner_class_stub(
                    name, attr, type_tracker=type_tracker
                )
                stub_content.extend(inner_class_stub)
                stub_content.append("")  # Add newline after the inner class

            # Add methods to the main class
            processed_methods = set()  # Keep track of methods we've processed
            for name, method in sorted(
                inspect.getmembers(async_class, predicate=inspect.isfunction)
            ):
                if name.startswith("_") or name in processed_methods:
                    continue

                processed_methods.add(name)

                try:
                    method_sig = cls._generate_method_signature(
                        name, method, is_async=True, type_tracker=type_tracker
                    )

                    # Add docstring if available (before the method signature for proper formatting)
                    if method.__doc__:
                        stub_content.extend(
                            cls._format_docstring_for_stub(method.__doc__, "    ")
                        )

                    stub_content.append(f"    {method_sig}")
                    stub_content.append("")  # Add newline after each method

                except (ValueError, TypeError):
                    # If we can't get the signature, just add a simple stub
                    stub_content.append(f"    def {name}(self, *args, **kwargs): ...")
                    stub_content.append("")  # Add newline

            # Add properties
            for name, prop in sorted(
                inspect.getmembers(async_class, lambda x: isinstance(x, property))
            ):
                stub_content.append("    @property")
                stub_content.append(f"    def {name}(self) -> Any: ...")
                if prop.fset:
                    stub_content.append(f"    @{name}.setter")
                    stub_content.append(
                        f"    def {name}(self, value: Any) -> None: ..."
                    )
                stub_content.append("")  # Add newline after each property

            # Add placeholders for the nested class instances
            # Check the actual attribute names from class annotations and attributes
            attribute_mappings = {}

            # First check annotations for typed attributes (including from parent classes)
            # Collect all annotations from the class hierarchy
            all_annotations = {}
            for base_class in reversed(inspect.getmro(async_class)):
                if hasattr(base_class, "__annotations__"):
                    all_annotations.update(base_class.__annotations__)

            for attr_name, attr_type in sorted(all_annotations.items()):
                for class_name, class_type in class_attributes:
                    # If the class type matches the annotated type
                    if attr_type == class_type or (
                        hasattr(attr_type, "__name__")
                        and attr_type.__name__ == class_name
                    ):
                        attribute_mappings[class_name] = attr_name

            # Annotations should be sufficient; no extra attribute checking is needed

            # Add the attribute declarations with proper names
            for class_name, _ in class_attributes:
                # Use the attribute name if found in mappings, otherwise use class name
                attr_name = attribute_mappings.get(class_name, class_name)
                stub_content.append(f"    {attr_name}: {class_name}Sync")

            stub_content.append("")  # Add a final newline

            # Now generate imports with all discovered types
            imports = cls._generate_imports(async_class, type_tracker)

            # Deduplicate imports while preserving order
            seen = set()
            unique_imports = []
            for imp in imports:
                if imp not in seen:
                    seen.add(imp)
                    unique_imports.append(imp)
                else:
                    logging.warning(f"Duplicate import detected: {imp}")

            # Replace the placeholder with actual imports
            stub_content[imports_placeholder_index : imports_placeholder_index + 1] = (
                unique_imports
            )

            # Post-process stub content
            stub_content = cls._post_process_stub_content(stub_content)

            # Write stub file
            with open(sync_stub_path, "w") as f:
                f.write("\n".join(stub_content))

            logging.info(f"Generated stub file: {sync_stub_path}")

        except Exception as e:
            # If stub generation fails, log the error but don't break the main functionality
            logging.error(
                f"Error generating stub file for {sync_class.__name__}: {str(e)}"
            )
            import traceback

            logging.error(traceback.format_exc())


def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
    """
    Creates a sync version of an async class

    Args:
        async_class: The async class to convert
        thread_pool_size: Size of thread pool to use

    Returns:
        A new class with sync versions of all async methods
    """
    return AsyncToSyncConverter.create_sync_class(async_class, thread_pool_size)
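For orientation, a hedged sketch of the converter in use; the `Greeter` class here is hypothetical and only illustrates that each coroutine method gains a blocking counterpart:

import asyncio

class Greeter:
    async def greet(self, name: str) -> str:
        await asyncio.sleep(0)  # stand-in for real async work
        return f"hello, {name}"

# Wrap the async class; the returned class exposes sync versions of the methods.
GreeterSync = create_sync_class(Greeter)

greeter = GreeterSync()
print(greeter.greet("world"))  # blocking call, no await required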
comfy_api/internal/singleton.py (new file, 33 lines)
@@ -0,0 +1,33 @@
from typing import Type, TypeVar


class SingletonMetaclass(type):
    T = TypeVar("T", bound="SingletonMetaclass")
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(SingletonMetaclass, cls).__call__(
                *args, **kwargs
            )
        return cls._instances[cls]

    def inject_instance(cls: Type[T], instance: T) -> None:
        assert cls not in SingletonMetaclass._instances, (
            "Cannot inject instance after first instantiation"
        )
        SingletonMetaclass._instances[cls] = instance

    def get_instance(cls: Type[T], *args, **kwargs) -> T:
        """
        Gets the singleton instance of the class, creating it if it doesn't exist.
        """
        if cls not in SingletonMetaclass._instances:
            SingletonMetaclass._instances[cls] = super(
                SingletonMetaclass, cls
            ).__call__(*args, **kwargs)
        return cls._instances[cls]


class ProxiedSingleton(object, metaclass=SingletonMetaclass):
    def __init__(self):
        super().__init__()
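A short sketch of the metaclass behavior (the `Settings` class is hypothetical): every instantiation of a ProxiedSingleton subclass returns the same cached object.

class Settings(ProxiedSingleton):
    def __init__(self):
        super().__init__()
        self.theme = "dark"

a = Settings()
b = Settings()
assert a is b  # the metaclass caches one instance per class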
comfy_api/latest/__init__.py (new file, 84 lines)
@@ -0,0 +1,84 @@
from __future__ import annotations

from typing import Type, TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest.input import ImageInput
from comfy_api.latest._io import _IO as io  # noqa: F401
from comfy_api.latest._ui import _UI as ui  # noqa: F401
from comfy_api.latest._resources import _RESOURCES as resources  # noqa: F401
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state
from PIL import Image
from comfy.cli_args import args
import numpy as np


class ComfyAPI_latest(ComfyAPIBase):
    VERSION = "latest"
    STABLE = False

    class Execution(ProxiedSingleton):
        async def set_progress(
            self,
            value: float,
            max_value: float,
            node_id: str | None = None,
            preview_image: Image.Image | ImageInput | None = None,
            ignore_size_limit: bool = False,
        ) -> None:
            """
            Update the progress bar displayed in the ComfyUI interface.

            This function allows custom nodes and API calls to report their progress
            back to the user interface, providing visual feedback during long operations.

            Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
            """
            executing_context = get_executing_context()
            if node_id is None and executing_context is not None:
                node_id = executing_context.node_id
            if node_id is None:
                raise ValueError("node_id must be provided if not in executing context")

            # Convert preview_image to PreviewImageTuple if needed
            if preview_image is not None:
                # First convert to PIL Image if needed
                if isinstance(preview_image, ImageInput):
                    # Convert ImageInput (torch.Tensor) to PIL Image
                    # Handle tensor shape [B, H, W, C] -> get first image if batch
                    tensor = preview_image
                    if len(tensor.shape) == 4:
                        tensor = tensor[0]

                    # Convert to numpy array and scale to 0-255
                    image_np = (tensor.cpu().numpy() * 255).astype(np.uint8)
                    preview_image = Image.fromarray(image_np)

                if isinstance(preview_image, Image.Image):
                    # Detect image format from PIL Image
                    image_format = preview_image.format if preview_image.format else "JPEG"
                    # Use None for preview_size if ignore_size_limit is True
                    preview_size = None if ignore_size_limit else args.preview_size
                    preview_image = (image_format, preview_image, preview_size)

            get_progress_state().update_progress(
                node_id=node_id,
                value=value,
                max_value=max_value,
                image=preview_image,
            )

    execution: Execution


ComfyAPI = ComfyAPI_latest

# Create a synchronous version of the API
if TYPE_CHECKING:
    import comfy_api.latest.generated.ComfyAPISyncStub  # type: ignore

    ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPI_latest)
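A hedged usage sketch from inside a node implementation (assumes the code runs within an execution context, so node_id is inferred automatically):

from comfy_api.latest import ComfyAPI

api = ComfyAPI()

async def long_running_step(total_steps: int):
    for step in range(total_steps):
        # ... one unit of work ...
        # node_id is picked up from the executing context
        await api.execution.set_progress(value=step + 1, max_value=total_steps)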
@@ -24,7 +24,7 @@ from comfy.sd import StyleModel as StyleModel_
 from comfy_api.input import VideoInput
 from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
                                 prune_dict, shallow_clone_class)
-from comfy_api.v3._resources import Resources, ResourcesLocal
+from comfy_api.latest._resources import Resources, ResourcesLocal
 from comfy_execution.graph import ExecutionBlocker

 # from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference
@@ -17,7 +17,7 @@ import folder_paths

 # used for image preview
 from comfy.cli_args import args
-from comfy_api.v3._io import ComfyNode, FolderType, Image, _UIOutput
+from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput


 class SavedResult(dict):
comfy_api/latest/generated/ComfyAPISyncStub.pyi (new file, 20 lines)
@@ -0,0 +1,20 @@
from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple
from comfy_api.latest import ComfyAPI_latest
from PIL.Image import Image
from torch import Tensor

class ComfyAPISyncStub:
    def __init__(self) -> None: ...

    class ExecutionSync:
        def __init__(self) -> None: ...
        """
        Update the progress bar displayed in the ComfyUI interface.

        This function allows custom nodes and API calls to report their progress
        back to the user interface, providing visual feedback during long operations.

        Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
        """
        def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[tuple[str, Image, Union[int, None]], Image, Tensor, None] = None) -> None: ...

    execution: ExecutionSync
comfy_api/latest/input/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .video_types import VideoInput

__all__ = [
    "ImageInput",
    "AudioInput",
    "VideoInput",
    "MaskInput",
    "LatentInput",
]
comfy_api/latest/input/basic_types.py (new file, 42 lines)
@@ -0,0 +1,42 @@
import torch
from typing import TypedDict, List, Optional

ImageInput = torch.Tensor
"""
An image in format [B, H, W, C] where B is the batch size, C is the number of channels,
H is the height, and W is the width.
"""

MaskInput = torch.Tensor
"""
A mask in format [B, H, W] where B is the batch size
"""

class AudioInput(TypedDict):
    """
    TypedDict representing audio input.
    """

    waveform: torch.Tensor
    """
    Tensor in the format [B, C, T] where B is the batch size, C is the number of channels,
    and T is the number of samples.
    """

    sample_rate: int

class LatentInput(TypedDict):
    """
    TypedDict representing latent input.
    """

    samples: torch.Tensor
    """
    Tensor in the format [B, C, H, W] where B is the batch size, C is the number of channels,
    H is the height, and W is the width.
    """

    noise_mask: Optional[MaskInput]
    """
    Optional noise mask tensor in the same format as samples.
    """

    batch_index: Optional[List[int]]
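A minimal sketch of constructing these types (shapes are illustrative only):

import torch

audio = AudioInput(
    waveform=torch.zeros(1, 2, 48000),  # [B, C, T]: one stereo second at 48 kHz
    sample_rate=48000,
)

latent = LatentInput(
    samples=torch.zeros(1, 4, 64, 64),  # [B, C, H, W]
    noise_mask=None,
    batch_index=None,
)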
comfy_api/latest/input/video_types.py (new file, 72 lines)
@@ -0,0 +1,72 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, Union
import io
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents


class VideoInput(ABC):
    """
    Abstract base class for video input types.
    """

    @abstractmethod
    def get_components(self) -> VideoComponents:
        """
        Abstract method to get the video components (images, audio, and frame rate).

        Returns:
            VideoComponents containing images, audio, and frame rate
        """
        pass

    @abstractmethod
    def save_to(
        self,
        path: str,
        format: VideoContainer = VideoContainer.AUTO,
        codec: VideoCodec = VideoCodec.AUTO,
        metadata: Optional[dict] = None
    ):
        """
        Abstract method to save the video input to a file.
        """
        pass

    def get_stream_source(self) -> Union[str, io.BytesIO]:
        """
        Get a streamable source for the video. This allows processing without
        loading the entire video into memory.

        Returns:
            Either a file path (str) or a BytesIO object that can be opened with av.

        Default implementation creates a BytesIO buffer, but subclasses should
        override this for better performance when possible.
        """
        buffer = io.BytesIO()
        self.save_to(buffer)
        buffer.seek(0)
        return buffer

    # Provide a default implementation, but subclasses can provide optimized versions
    # if possible.
    def get_dimensions(self) -> tuple[int, int]:
        """
        Returns the dimensions of the video input.

        Returns:
            Tuple of (width, height)
        """
        components = self.get_components()
        return components.images.shape[2], components.images.shape[1]

    def get_duration(self) -> float:
        """
        Returns the duration of the video in seconds.

        Returns:
            Duration in seconds
        """
        components = self.get_components()
        frame_count = components.images.shape[0]
        return float(frame_count / components.frame_rate)
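To illustrate the contract, a hedged sketch of a trivial subclass (the class name and content are hypothetical, and the types above are assumed imported); the inherited get_dimensions and get_duration defaults then work unchanged:

import torch
from fractions import Fraction

class BlackVideo(VideoInput):
    """One second of 24 fps black frames, held in memory."""

    def get_components(self) -> VideoComponents:
        return VideoComponents(
            images=torch.zeros(24, 64, 64, 3),  # [frames, H, W, C]
            frame_rate=Fraction(24),
        )

    def save_to(self, path, format=VideoContainer.AUTO,
                codec=VideoCodec.AUTO, metadata=None):
        raise NotImplementedError("illustration only")

video = BlackVideo()
print(video.get_dimensions())  # (64, 64)
print(video.get_duration())    # 1.0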
comfy_api/latest/input_impl/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
from .video_types import VideoFromFile, VideoFromComponents

__all__ = [
    # Implementations
    "VideoFromFile",
    "VideoFromComponents",
]
comfy_api/latest/input_impl/video_types.py (new file, 312 lines)
@@ -0,0 +1,312 @@
from __future__ import annotations
from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
from comfy_api.latest.input import AudioInput, VideoInput
import av
import io
import json
import numpy as np
import torch
from comfy_api.latest.util import VideoContainer, VideoCodec, VideoComponents


def container_to_output_format(container_format: str | None) -> str | None:
    """
    A container's `format` may be a comma-separated list of formats.
    E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
    However, writing to a file/stream with `av.open` requires a single format,
    or `None` to auto-detect.
    """
    if not container_format:
        return None  # Auto-detect

    if "," not in container_format:
        return container_format

    formats = container_format.split(",")
    return formats[0]


def get_open_write_kwargs(
    dest: str | io.BytesIO, container_format: str, to_format: str | None
) -> dict:
    """Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
    open_kwargs = {
        "mode": "w",
        # If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
        "options": {"movflags": "use_metadata_tags"},
    }

    is_write_to_buffer = isinstance(dest, io.BytesIO)
    if is_write_to_buffer:
        # Set output format explicitly, since it cannot be inferred from file extension
        if to_format == VideoContainer.AUTO:
            to_format = container_format.lower()
        elif isinstance(to_format, str):
            to_format = to_format.lower()
        open_kwargs["format"] = container_to_output_format(to_format)

    return open_kwargs


class VideoFromFile(VideoInput):
    """
    Class representing video input from a file.
    """

    def __init__(self, file: str | io.BytesIO):
        """
        Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
        containing the file contents.
        """
        self.__file = file

    def get_stream_source(self) -> str | io.BytesIO:
        """
        Return the underlying file source for efficient streaming.
        This avoids unnecessary memory copies when the source is already a file path.
        """
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)
        return self.__file

    def get_dimensions(self) -> tuple[int, int]:
        """
        Returns the dimensions of the video input.

        Returns:
            Tuple of (width, height)
        """
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)  # Reset the BytesIO object to the beginning
        with av.open(self.__file, mode='r') as container:
            for stream in container.streams:
                if stream.type == 'video':
                    assert isinstance(stream, av.VideoStream)
                    return stream.width, stream.height
        raise ValueError(f"No video stream found in file '{self.__file}'")

    def get_duration(self) -> float:
        """
        Returns the duration of the video in seconds.

        Returns:
            Duration in seconds
        """
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)
        with av.open(self.__file, mode="r") as container:
            if container.duration is not None:
                return float(container.duration / av.time_base)

            # Fallback: calculate from frame count and frame rate
            video_stream = next(
                (s for s in container.streams if s.type == "video"), None
            )
            if video_stream and video_stream.frames and video_stream.average_rate:
                return float(video_stream.frames / video_stream.average_rate)

            # Last resort: decode frames to count them
            if video_stream and video_stream.average_rate:
                frame_count = 0
                container.seek(0)
                for packet in container.demux(video_stream):
                    for _ in packet.decode():
                        frame_count += 1
                if frame_count > 0:
                    return float(frame_count / video_stream.average_rate)

        raise ValueError(f"Could not determine duration for file '{self.__file}'")

    def get_components_internal(self, container: InputContainer) -> VideoComponents:
        # Get video frames
        frames = []
        for frame in container.decode(video=0):
            img = frame.to_ndarray(format='rgb24')  # shape: (H, W, 3)
            img = torch.from_numpy(img) / 255.0  # shape: (H, W, 3)
            frames.append(img)

        images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)

        # Get frame rate
        video_stream = next(s for s in container.streams if s.type == 'video')
        frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)

        # Get audio if available
        audio = None
        try:
            container.seek(0)  # Reset the container to the beginning
            for stream in container.streams:
                if stream.type != 'audio':
                    continue
                assert isinstance(stream, av.AudioStream)
                audio_frames = []
                for packet in container.demux(stream):
                    for frame in packet.decode():
                        assert isinstance(frame, av.AudioFrame)
                        audio_frames.append(frame.to_ndarray())  # shape: (channels, samples)
                if len(audio_frames) > 0:
                    audio_data = np.concatenate(audio_frames, axis=1)  # shape: (channels, total_samples)
                    audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)  # shape: (1, channels, total_samples)
                    audio = AudioInput({
                        "waveform": audio_tensor,
                        "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
                    })
        except StopIteration:
            pass  # No audio stream

        metadata = container.metadata
        return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)

    def get_components(self) -> VideoComponents:
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)  # Reset the BytesIO object to the beginning
        with av.open(self.__file, mode='r') as container:
            return self.get_components_internal(container)
        raise ValueError(f"No video stream found in file '{self.__file}'")

    def save_to(
        self,
        path: str | io.BytesIO,
        format: VideoContainer = VideoContainer.AUTO,
        codec: VideoCodec = VideoCodec.AUTO,
        metadata: Optional[dict] = None
    ):
        if isinstance(self.__file, io.BytesIO):
            self.__file.seek(0)  # Reset the BytesIO object to the beginning
        with av.open(self.__file, mode='r') as container:
            container_format = container.format.name
            video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
            reuse_streams = True
            if format != VideoContainer.AUTO and format not in container_format.split(","):
                reuse_streams = False
            if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
                reuse_streams = False

            if not reuse_streams:
                components = self.get_components_internal(container)
                video = VideoFromComponents(components)
                return video.save_to(
                    path,
                    format=format,
                    codec=codec,
                    metadata=metadata
                )

            streams = container.streams

            open_kwargs = get_open_write_kwargs(path, container_format, format)
            with av.open(path, **open_kwargs) as output_container:
                # Copy over the original metadata
                for key, value in container.metadata.items():
                    if metadata is None or key not in metadata:
                        output_container.metadata[key] = value

                # Add our new metadata
                if metadata is not None:
                    for key, value in metadata.items():
                        if isinstance(value, str):
                            output_container.metadata[key] = value
                        else:
                            output_container.metadata[key] = json.dumps(value)

                # Add streams to the new container
                stream_map = {}
                for stream in streams:
                    if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
                        out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
                        stream_map[stream] = out_stream

                # Write packets to the new container
                for packet in container.demux():
                    if packet.stream in stream_map and packet.dts is not None:
                        packet.stream = stream_map[packet.stream]
                        output_container.mux(packet)


class VideoFromComponents(VideoInput):
    """
    Class representing video input from tensors.
    """

    def __init__(self, components: VideoComponents):
        self.__components = components

    def get_components(self) -> VideoComponents:
        return VideoComponents(
            images=self.__components.images,
            audio=self.__components.audio,
            frame_rate=self.__components.frame_rate
        )

    def save_to(
        self,
        path: str,
        format: VideoContainer = VideoContainer.AUTO,
        codec: VideoCodec = VideoCodec.AUTO,
        metadata: Optional[dict] = None
    ):
        if format != VideoContainer.AUTO and format != VideoContainer.MP4:
            raise ValueError("Only MP4 format is supported for now")
        if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
            raise ValueError("Only H264 codec is supported for now")
        with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
            # Add metadata before writing any streams
            if metadata is not None:
                for key, value in metadata.items():
                    output.metadata[key] = json.dumps(value)

            frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
            # Create a video stream
            video_stream = output.add_stream('h264', rate=frame_rate)
            video_stream.width = self.__components.images.shape[2]
            video_stream.height = self.__components.images.shape[1]
            video_stream.pix_fmt = 'yuv420p'

            # Create an audio stream
            audio_sample_rate = 1
            audio_stream: Optional[av.AudioStream] = None
            if self.__components.audio:
                audio_sample_rate = int(self.__components.audio['sample_rate'])
                audio_stream = output.add_stream('aac', rate=audio_sample_rate)
                audio_stream.sample_rate = audio_sample_rate
                audio_stream.format = 'fltp'

            # Encode video
            for i, frame in enumerate(self.__components.images):
                img = (frame * 255).clamp(0, 255).byte().cpu().numpy()  # shape: (H, W, 3)
                frame = av.VideoFrame.from_ndarray(img, format='rgb24')
                frame = frame.reformat(format='yuv420p')  # Convert to YUV420P as required by h264
                packet = video_stream.encode(frame)
                output.mux(packet)

            # Flush video
            packet = video_stream.encode(None)
            output.mux(packet)

            if audio_stream and self.__components.audio:
                # Encode audio
                samples_per_frame = int(audio_sample_rate / frame_rate)
                num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
                for i in range(num_frames):
                    start = i * samples_per_frame
                    end = start + samples_per_frame
                    # TODO(Feature) - Add support for stereo audio
                    chunk = (
                        self.__components.audio["waveform"][0, 0, start:end]
                        .unsqueeze(0)
                        .contiguous()
                        .numpy()
                    )
                    audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
                    audio_frame.sample_rate = audio_sample_rate
                    audio_frame.pts = i * samples_per_frame
                    for packet in audio_stream.encode(audio_frame):
                        output.mux(packet)

                # Flush audio
                for packet in audio_stream.encode(None):
                    output.mux(packet)
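A hedged usage sketch (file names are illustrative): probe a video on disk and remux it with extra metadata attached.

video = VideoFromFile("input.mp4")
width, height = video.get_dimensions()  # read from the stream header, no decode
seconds = video.get_duration()

# When the target format and codec already match the source, streams are copied
# without re-encoding; otherwise save_to falls back to VideoFromComponents.
video.save_to(
    "output.mp4",
    format=VideoContainer.MP4,
    codec=VideoCodec.H264,
    metadata={"workflow": {"note": "example"}},
)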
comfy_api/latest/util/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents

__all__ = [
    # Utility Types
    "VideoContainer",
    "VideoCodec",
    "VideoComponents",
]
comfy_api/latest/util/video_types.py (new file, 52 lines)
@@ -0,0 +1,52 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from fractions import Fraction
from typing import Optional
from comfy_api.latest.input import ImageInput, AudioInput


class VideoCodec(str, Enum):
    AUTO = "auto"
    H264 = "h264"

    @classmethod
    def as_input(cls) -> list[str]:
        """
        Returns a list of codec names that can be used as node input.
        """
        return [member.value for member in cls]


class VideoContainer(str, Enum):
    AUTO = "auto"
    MP4 = "mp4"

    @classmethod
    def as_input(cls) -> list[str]:
        """
        Returns a list of container names that can be used as node input.
        """
        return [member.value for member in cls]

    @classmethod
    def get_extension(cls, value) -> str:
        """
        Returns the file extension for the container.
        """
        if isinstance(value, str):
            value = cls(value)
        if value == VideoContainer.MP4 or value == VideoContainer.AUTO:
            return "mp4"
        return ""


@dataclass
class VideoComponents:
    """
    Dataclass representing the components of a video.
    """

    images: ImageInput
    frame_rate: Fraction
    audio: Optional[AudioInput] = None
    metadata: Optional[dict] = None
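A small sketch of these utility types together (values illustrative):

import torch
from fractions import Fraction

print(VideoCodec.as_input())                 # ['auto', 'h264']
print(VideoContainer.get_extension("auto"))  # 'mp4'

components = VideoComponents(
    images=torch.zeros(24, 64, 64, 3),  # [frames, H, W, C]
    frame_rate=Fraction(24),            # audio and metadata default to None
)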
comfy_api/util.py (new file, 2 lines)
@@ -0,0 +1,2 @@
# This file only exists for backwards compatibility.
from comfy_api.latest.util import *  # noqa: F403

@@ -1,8 +1,2 @@
-from .video_types import VideoContainer, VideoCodec, VideoComponents
-
-__all__ = [
-    # Utility Types
-    "VideoContainer",
-    "VideoCodec",
-    "VideoComponents",
-]
+# This file only exists for backwards compatibility.
+from comfy_api.latest.util import *  # noqa: F403
@@ -1,51 +1,2 @@
-from __future__ import annotations
-from dataclasses import dataclass
-from enum import Enum
-from fractions import Fraction
-from typing import Optional
-from comfy_api.input import ImageInput, AudioInput
-
-
-class VideoCodec(str, Enum):
-    AUTO = "auto"
-    H264 = "h264"
-
-    @classmethod
-    def as_input(cls) -> list[str]:
-        """
-        Returns a list of codec names that can be used as node input.
-        """
-        return [member.value for member in cls]
-
-
-class VideoContainer(str, Enum):
-    AUTO = "auto"
-    MP4 = "mp4"
-
-    @classmethod
-    def as_input(cls) -> list[str]:
-        """
-        Returns a list of container names that can be used as node input.
-        """
-        return [member.value for member in cls]
-
-    @classmethod
-    def get_extension(cls, value) -> str:
-        """
-        Returns the file extension for the container.
-        """
-        if isinstance(value, str):
-            value = cls(value)
-        if value == VideoContainer.MP4 or value == VideoContainer.AUTO:
-            return "mp4"
-        return ""
-
-
-@dataclass
-class VideoComponents:
-    """
-    Dataclass representing the components of a video.
-    """
-
-    images: ImageInput
-    frame_rate: Fraction
-    audio: Optional[AudioInput] = None
-    metadata: Optional[dict] = None
+# This file only exists for backwards compatibility.
+from comfy_api.latest.util.video_types import *  # noqa: F403
comfy_api/v0_0_1/__init__.py (new file, 18 lines)
@@ -0,0 +1,18 @@
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
from typing import Type, TYPE_CHECKING
from comfy_api.internal.async_to_sync import create_sync_class


# This version only exists to serve as a template for future version adapters.
# There is no reason anyone should ever use it.
class ComfyAPIAdapter_v0_0_1(ComfyAPIAdapter_v0_0_2):
    VERSION = "0.0.1"
    STABLE = True


ComfyAPI = ComfyAPIAdapter_v0_0_1

# Create a synchronous version of the API
if TYPE_CHECKING:
    from comfy_api.v0_0_1.generated.ComfyAPISyncStub import ComfyAPISyncStub  # type: ignore

    ComfyAPISync: Type[ComfyAPISyncStub]

ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_1)
comfy_api/v0_0_1/generated/ComfyAPISyncStub.pyi (new file, 20 lines)
@@ -0,0 +1,20 @@
from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
from PIL.Image import Image
from torch import Tensor

class ComfyAPISyncStub:
    def __init__(self) -> None: ...

    class ExecutionSync:
        def __init__(self) -> None: ...
        """
        Update the progress bar displayed in the ComfyUI interface.

        This function allows custom nodes and API calls to report their progress
        back to the user interface, providing visual feedback during long operations.

        Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
        """
        def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[tuple[str, Image, Union[int, None]], Image, Tensor, None] = None) -> None: ...

    execution: ExecutionSync
comfy_api/v0_0_2/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from comfy_api.latest import ComfyAPI_latest
from typing import Type, TYPE_CHECKING
from comfy_api.internal.async_to_sync import create_sync_class


class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
    VERSION = "0.0.2"
    STABLE = False


ComfyAPI = ComfyAPIAdapter_v0_0_2

# Create a synchronous version of the API
if TYPE_CHECKING:
    from comfy_api.v0_0_2.generated.ComfyAPISyncStub import ComfyAPISyncStub  # type: ignore

    ComfyAPISync: Type[ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_2)
comfy_api/v0_0_2/generated/ComfyAPISyncStub.pyi (new file, 20 lines)
@@ -0,0 +1,20 @@
from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
from PIL.Image import Image
from torch import Tensor

class ComfyAPISyncStub:
    def __init__(self) -> None: ...

    class ExecutionSync:
        def __init__(self) -> None: ...
        """
        Update the progress bar displayed in the ComfyUI interface.

        This function allows custom nodes and API calls to report their progress
        back to the user interface, providing visual feedback during long operations.

        Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
        """
        def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[tuple[str, Image, Union[int, None]], Image, Tensor, None] = None) -> None: ...

    execution: ExecutionSync
@@ -1,9 +0,0 @@
-from comfy_api.v3._io import _IO
-from comfy_api.v3._ui import _UI
-from comfy_api.v3._resources import _RESOURCES
-
-io = _IO
-ui = _UI
-resources = _RESOURCES
-
-__all__ = ["io", "ui", "resources"]
comfy_api/version_list.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from comfy_api.latest import ComfyAPI_latest
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
from comfy_api.internal import ComfyAPIBase
from typing import List, Type

supported_versions: List[Type[ComfyAPIBase]] = [
    ComfyAPI_latest,
    ComfyAPIAdapter_v0_0_2,
    ComfyAPIAdapter_v0_0_1,
]
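For illustration, a sketch that walks this list (every entry defines VERSION and STABLE, as shown in the adapters above):

for api_cls in supported_versions:
    print(api_cls.VERSION, "stable" if api_cls.STABLE else "unstable")
# latest unstable
# 0.0.2 unstable
# 0.0.1 stable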
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 import datetime
 import json

@@ -2,6 +2,8 @@
 API Nodes for Gemini Multimodal LLM Usage via Remote API
 See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
 """
+from __future__ import annotations
+

 import os
 from enum import Enum
@@ -1,4 +1,6 @@
-from typing import TypedDict, Dict, Optional
+from __future__ import annotations
+
+from typing import TypedDict, Dict, Optional, Tuple
 from typing_extensions import override
 from PIL import Image
 from enum import Enum
@@ -10,6 +12,7 @@ if TYPE_CHECKING:
 from protocol import BinaryEventTypes
 from comfy_api import feature_flags

+PreviewImageTuple = Tuple[str, Image.Image, Optional[int]]

 class NodeState(Enum):
     Pending = "pending"
@@ -52,7 +55,7 @@ class ProgressHandler(ABC):
         max_value: float,
         state: NodeProgressState,
         prompt_id: str,
-        image: Optional[Image.Image] = None,
+        image: PreviewImageTuple | None = None,
     ):
         """Called when a node's progress is updated"""
         pass
@@ -103,7 +106,7 @@ class CLIProgressHandler(ProgressHandler):
         max_value: float,
         state: NodeProgressState,
         prompt_id: str,
-        image: Optional[Image.Image] = None,
+        image: PreviewImageTuple | None = None,
     ):
         # Handle case where start_handler wasn't called
         if node_id not in self.progress_bars:
@@ -196,7 +199,7 @@ class WebUIProgressHandler(ProgressHandler):
         max_value: float,
         state: NodeProgressState,
         prompt_id: str,
-        image: Optional[Image.Image] = None,
+        image: PreviewImageTuple | None = None,
     ):
         # Send progress state of all nodes
         if self.registry:
@@ -231,7 +234,6 @@ class WebUIProgressHandler(ProgressHandler):
         if self.registry:
             self._send_progress_state(prompt_id, self.registry.nodes)
-

 class ProgressRegistry:
     """
     Registry that maintains node progress state and notifies registered handlers.
@@ -285,7 +287,7 @@
             handler.start_handler(node_id, entry, self.prompt_id)

     def update_progress(
-        self, node_id: str, value: float, max_value: float, image: Optional[Image.Image]
+        self, node_id: str, value: float, max_value: float, image: PreviewImageTuple | None = None
     ) -> None:
         """Update progress for a node"""
         entry = self.ensure_entry(node_id)
@@ -317,7 +319,7 @@
         handler.reset()

 # Global registry instance
-global_progress_registry: ProgressRegistry = None
+global_progress_registry: ProgressRegistry | None = None

 def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
     global global_progress_registry
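Progress handlers now receive a tuple rather than a bare PIL image; a sketch of its shape (values illustrative):

from PIL import Image

preview: PreviewImageTuple = (
    "JPEG",                      # image format
    Image.new("RGB", (64, 64)),  # the preview image itself
    512,                         # optional max preview size (None to skip resizing)
)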
@@ -1,6 +1,6 @@
 import torch
 import time
-from comfy_api.v3 import io, ui, resources, _io
+from comfy_api.latest import io, ui, resources, _io
 import logging  # noqa
 import folder_paths
 import comfy.utils
@@ -8,9 +8,9 @@ import json
 from typing import Optional, Literal
 from fractions import Fraction
 from comfy.comfy_types import IO, FileLocator, ComfyNodeABC
-from comfy_api.input import ImageInput, AudioInput, VideoInput
-from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
-from comfy_api.input_impl import VideoFromFile, VideoFromComponents
+from comfy_api.latest.input import ImageInput, AudioInput, VideoInput
+from comfy_api.latest.util import VideoContainer, VideoCodec, VideoComponents
+from comfy_api.latest.input_impl import VideoFromFile, VideoFromComponents
 from comfy.cli_args import args

 class SaveWEBM:
@@ -239,3 +239,4 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "GetVideoComponents": "Get Video Components",
     "LoadVideo": "Load Video",
 }
+
@ -4,7 +4,7 @@ import torch
|
|||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import node_helpers
|
import node_helpers
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
|
|
||||||
class TextEncodeAceStepAudio(io.ComfyNode):
|
class TextEncodeAceStepAudio(io.ComfyNode):
|
||||||
|
@ -6,7 +6,7 @@ import comfy.model_patcher
|
|||||||
import comfy.samplers
|
import comfy.samplers
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
from comfy.k_diffusion.sampling import to_d
|
from comfy.k_diffusion.sampling import to_d
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
NOISE_LEVELS = {
|
NOISE_LEVELS = {
|
||||||
"SD1": [
|
"SD1": [
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
|
|
||||||
def project(v0, v1):
|
def project(v0, v1):
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
|
|
||||||
def attention_multiply(attn, model, q, k, v, out):
|
def attention_multiply(attn, model, q, k, v, out):
|
||||||
|
@ -9,7 +9,7 @@ import torchaudio
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import node_helpers
|
import node_helpers
|
||||||
from comfy_api.v3 import io, ui
|
from comfy_api.latest import io, ui
|
||||||
|
|
||||||
|
|
||||||
class ConditioningStableAudio(io.ComfyNode):
|
class ConditioningStableAudio(io.ComfyNode):
|
||||||
|
@ -6,7 +6,7 @@ from einops import rearrange
|
|||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import nodes
|
import nodes
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
CAMERA_DICT = {
|
CAMERA_DICT = {
|
||||||
"base_T_norm": 1.5,
|
"base_T_norm": 1.5,
|
||||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from kornia.filters import canny
|
from kornia.filters import canny
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
|
|
||||||
class Canny(io.ComfyNode):
|
class Canny(io.ComfyNode):
|
||||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
|
|
||||||
def optimized_scale(positive, negative):
|
def optimized_scale(positive, negative):
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import nodes
|
import nodes
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextEncodeSDXL(io.ComfyNode):
|
class CLIPTextEncodeSDXL(io.ComfyNode):
|
||||||
|
@ -5,7 +5,7 @@ from enum import Enum
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
|
|
||||||
def resize_mask(mask, shape):
|
def resize_mask(mask, shape):
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextEncodeControlnet(io.ComfyNode):
|
class CLIPTextEncodeControlnet(io.ComfyNode):
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
|
|
||||||
class ControlNetApplyAdvanced(io.ComfyNode):
|
class ControlNetApplyAdvanced(io.ComfyNode):
|
||||||
|
@ -6,7 +6,7 @@ import comfy.latent_formats
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import nodes
|
import nodes
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
|
|
||||||
def vae_encode_with_padding(vae, image, width, height, length, padding=0):
|
def vae_encode_with_padding(vae, image, width, height, length, padding=0):
|
||||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
|
|
||||||
class DifferentialDiffusion(io.ComfyNode):
|
class DifferentialDiffusion(io.ComfyNode):
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import node_helpers
|
import node_helpers
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
|
|
||||||
class ReferenceLatent(io.ComfyNode):
|
class ReferenceLatent(io.ComfyNode):
|
||||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import node_helpers
|
import node_helpers
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
PREFERED_KONTEXT_RESOLUTIONS = [
|
PREFERED_KONTEXT_RESOLUTIONS = [
|
||||||
(672, 1568),
|
(672, 1568),
|
||||||
|
@ -6,7 +6,7 @@ import logging
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
|
|
||||||
def Fourier_filter(x, threshold, scale):
|
def Fourier_filter(x, threshold, scale):
|
||||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
import torch
|
import torch
|
||||||
import torch.fft as fft
|
import torch.fft as fft
|
||||||
|
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
|
|
||||||
def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
|
def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
|
||||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
|
|
||||||
def loglinear_interp(t_steps, num_steps):
|
def loglinear_interp(t_steps, num_steps):
|
||||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.sd
|
import comfy.sd
|
||||||
import folder_paths
|
import folder_paths
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextEncodeHiDream(io.ComfyNode):
|
class CLIPTextEncodeHiDream(io.ComfyNode):
|
||||||
|
@ -5,7 +5,7 @@ import torch
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import node_helpers
|
import node_helpers
|
||||||
import nodes
|
import nodes
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
|
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
|
||||||
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
|
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
|
||||||
|
@ -6,7 +6,7 @@ import torch
|
|||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import folder_paths
|
import folder_paths
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
|
|
||||||
def load_hypernetwork_patch(path, strength):
|
def load_hypernetwork_patch(path, strength):
|
||||||
|
@ -7,7 +7,7 @@ import math
|
|||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torch import randint
|
from torch import randint
|
||||||
|
|
||||||
from comfy_api.v3 import io
|
from comfy_api.latest import io
|
||||||
|
|
||||||
|
|
||||||
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
|
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
|
||||||
|
@ -9,7 +9,7 @@ import comfy.utils
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
import node_helpers
|
import node_helpers
|
||||||
import nodes
|
import nodes
|
||||||
from comfy_api.v3 import io, ui
|
from comfy_api.latest import io, ui
|
||||||
from server import PromptServer
|
from server import PromptServer
|
||||||
|
|
||||||
|
|
||||||
|
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import torch
 
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 class InstructPixToPixConditioning(io.ComfyNode):

@@ -4,7 +4,7 @@ import torch
 
 import comfy.utils
 import comfy_extras.nodes_post_processing
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 def reshape_latent_to(target_shape, latent, repeat_batch=True):

@@ -6,7 +6,7 @@ from pathlib import Path
 import folder_paths
 import nodes
 from comfy_api.input_impl import VideoFromFile
-from comfy_api.v3 import io, ui
+from comfy_api.latest import io, ui
 
 
 def normalize_path(path):

@@ -9,7 +9,7 @@ import torch
 import comfy.model_management
 import comfy.utils
 import folder_paths
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 CLAMP_QUANTILE = 0.99
 

@@ -3,7 +3,7 @@ from __future__ import annotations
 import torch
 
 import comfy.model_management as mm
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 class LotusConditioning(io.ComfyNode):

@@ -16,7 +16,7 @@ from comfy.ldm.lightricks.symmetric_patchifier import (
     SymmetricPatchifier,
     latent_to_pixel_coords,
 )
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 def conditioning_get_any_value(conditioning, key, default=None):
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import torch
 
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 class CLIPTextEncodeLumina2(io.ComfyNode):

@@ -7,7 +7,7 @@ import torch
 import comfy.utils
 import node_helpers
 import nodes
-from comfy_api.v3 import io, ui
+from comfy_api.latest import io, ui
 
 
 def composite(destination, source, x, y, mask=None, multiplier=8, resize_source=False):

@@ -4,7 +4,7 @@ import torch
 
 import comfy.model_management
 import nodes
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 class EmptyMochiLatentVideo(io.ComfyNode):

@@ -7,7 +7,7 @@ import comfy.model_sampling
 import comfy.sd
 import node_helpers
 import nodes
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 class LCM(comfy.model_sampling.EPS):

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import comfy.utils
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 class PatchModelAddDownscale(io.ComfyNode):

@@ -13,7 +13,7 @@ from kornia.morphology import (
 )
 
 import comfy.model_management
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 class ImageRGBToYUV(io.ComfyNode):

@@ -3,7 +3,7 @@ from __future__ import annotations
 import numpy as np
 import torch
 
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 # from https://github.com/bebebe666/OptimalSteps
 
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import comfy.model_patcher
 import comfy.samplers
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 #Modified/simplified version of the node from: https://github.com/pamparamm/sd-perturbed-attention
 #If you want the one with more options see the above repo.

@@ -9,7 +9,7 @@ import comfy.sampler_helpers
 import comfy.samplers
 import comfy.utils
 import node_helpers
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale):

@@ -9,7 +9,7 @@ import comfy.model_management
 import comfy.ops
 import comfy.utils
 import folder_paths
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 # code for model from:
 # https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import nodes
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 class CLIPTextEncodePixArtAlpha(io.ComfyNode):

@@ -10,7 +10,7 @@ from PIL import Image
 import comfy.model_management
 import comfy.utils
 import node_helpers
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 def gaussian_kernel(kernel_size: int, sigma: float, device=None):

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import json
 
-from comfy_api.v3 import io, ui
+from comfy_api.latest import io, ui
 
 
 class PreviewAny(io.ComfyNode):

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import sys
 
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 class String(io.ComfyNode):
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import torch
 
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 class ImageRebatch(io.ComfyNode):

@@ -9,7 +9,7 @@ from torch import einsum
 
 import comfy.samplers
 from comfy.ldm.modules.attention import optimized_attention
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 # from comfy/ldm/modules/attention.py

@@ -6,7 +6,7 @@ import comfy.model_management
 import comfy.sd
 import folder_paths
 import nodes
-from comfy_api.v3 import io
+from comfy_api.latest import io
 from comfy_extras.v3.nodes_slg import SkipLayerGuidanceDiT
 
 

@@ -3,7 +3,7 @@ from __future__ import annotations
 import torch
 
 import comfy.utils
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 class SD_4XUpscale_Conditioning(io.ComfyNode):

@@ -4,7 +4,7 @@ import re
 
 import comfy.model_patcher
 import comfy.samplers
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 class SkipLayerGuidanceDiT(io.ComfyNode):

@@ -20,7 +20,7 @@ import torch
 
 import comfy.utils
 import nodes
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 class StableCascade_EmptyLatentImage(io.ComfyNode):

@@ -4,7 +4,7 @@ from __future__ import annotations
 
 import torch
 
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:

@@ -7,7 +7,7 @@ from typing import Callable, Tuple
 
 import torch
 
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 def do_nothing(x: torch.Tensor, mode:str=None):
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
+from comfy_api.latest import io
 from comfy_api.torch_helpers import set_torch_compile_wrapper
-from comfy_api.v3 import io
 
 
 class TorchCompileModel(io.ComfyNode):
@@ -18,7 +18,7 @@ import comfy_extras.nodes_custom_sampler
 import folder_paths
 import node_helpers
 from comfy.weight_adapter import adapter_maps, adapters
-from comfy_api.v3 import io, ui
+from comfy_api.latest import io, ui
 
 
 def make_batch_extra_option_dict(d, indicies, full_size=None):

@@ -8,7 +8,7 @@ from spandrel import ImageModelDescriptor, ModelLoader
 import comfy.utils
 import folder_paths
 from comfy import model_management
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 try:
     from spandrel import MAIN_REGISTRY

@@ -11,8 +11,8 @@ import folder_paths
 from comfy.cli_args import args
 from comfy_api.input import AudioInput, ImageInput, VideoInput
 from comfy_api.input_impl import VideoFromComponents, VideoFromFile
+from comfy_api.latest import io, ui
 from comfy_api.util import VideoCodec, VideoComponents, VideoContainer
-from comfy_api.v3 import io, ui
 
 
 class CreateVideo(io.ComfyNode):
@@ -8,7 +8,7 @@ import comfy_extras.nodes_model_merging
 import folder_paths
 import node_helpers
 import nodes
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 class ConditioningSetAreaPercentageVideo(io.ComfyNode):

@@ -8,7 +8,7 @@ import comfy.model_management
 import comfy.utils
 import node_helpers
 import nodes
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 class TrimVideoLatent(io.ComfyNode):

@@ -7,7 +7,7 @@ from PIL import Image, ImageOps, ImageSequence
 import folder_paths
 import node_helpers
 import nodes
-from comfy_api.v3 import io
+from comfy_api.latest import io
 
 
 class WebcamCapture(io.ComfyNode):
@@ -33,7 +33,7 @@ from comfy_execution.validation import validate_node_input
 from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
 from comfy_execution.utils import CurrentNodeContext
 from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
-from comfy_api.v3 import io, resources
+from comfy_api.latest import io, resources
 
 
 class ExecutionResult(Enum):
main.py (10 changed lines)
@@ -22,6 +22,12 @@ if __name__ == "__main__":
 
     setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
 
+    # Handle --generate-api-stubs early
+    if args.generate_api_stubs:
+        from comfy_api.generate_api_stubs import main as generate_stubs_main
+        generate_stubs_main()
+        sys.exit(0)
+
 def apply_custom_paths():
     # extra model paths
     extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
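Note: the early exit above means stub generation never starts the server. Run from a ComfyUI checkout, it is equivalent to this small sketch:

# equivalent to: python main.py --generate-api-stubs
from comfy_api.generate_api_stubs import main as generate_stubs_main

generate_stubs_main()  # writes .pyi stubs for the sync API wrappers, then the caller exits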
@@ -313,10 +319,10 @@ def start_comfyui(asyncio_loop=None):
     prompt_server = server.PromptServer(asyncio_loop)
 
     hook_breaker_ac10a0.save_functions()
-    nodes.init_extra_nodes(
+    asyncio_loop.run_until_complete(nodes.init_extra_nodes(
         init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
         init_api_nodes=not args.disable_api_nodes
-    )
+    ))
     hook_breaker_ac10a0.restore_functions()
 
     cuda_malloc_warning()
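Note: start_comfyui() stays synchronous, so the now-async nodes.init_extra_nodes() is driven to completion on the event loop the server already owns. A minimal, self-contained sketch of the same bridging pattern (the coroutine body is a stand-in):

import asyncio

async def init_extra_nodes(init_custom_nodes=True, init_api_nodes=True):
    return []  # stand-in for the real coroutine in nodes.py

# a synchronous caller drives the coroutine to completion on its own loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(init_extra_nodes(init_custom_nodes=True, init_api_nodes=True))
loop.close()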
nodes.py (38 changed lines)
@@ -1,6 +1,7 @@
 from __future__ import annotations
 import torch
 
+
 import os
 import sys
 import json
@@ -26,7 +27,9 @@ import comfy.sd
 import comfy.utils
 import comfy.controlnet
 from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
-from comfy_api.v3 import io
+from comfy_api.internal import register_versions, ComfyAPIWithVersion
+from comfy_api.version_list import supported_versions
+from comfy_api.latest import io
 
 import comfy.clip_vision
 
@@ -2102,7 +2105,7 @@ def get_module_name(module_path: str) -> str:
     return base_path
 
 
-def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
+async def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
     module_name = get_module_name(module_path)
     if os.path.isfile(module_path):
        sp = os.path.splitext(module_path)
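Note: once load_custom_node is a coroutine, calling it without await returns a coroutine object (which is always truthy) instead of a bool, silently breaking every `if not load_custom_node(...)` check; that is why each call site in the hunks below gains an await. A self-contained illustration (the loader body is a stand-in):

import asyncio

async def load_custom_node(module_path: str) -> bool:
    return True  # stand-in for the real import logic

async def demo():
    broken = load_custom_node("comfy_extras/nodes_example.py")   # coroutine object, always truthy
    ok = await load_custom_node("comfy_extras/nodes_example.py")  # actual bool
    broken.close()  # dispose of the never-awaited coroutine
    return ok

print(asyncio.run(demo()))  # True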
@@ -2178,7 +2181,7 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
         logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
         return False
 
-def init_external_custom_nodes():
+async def init_external_custom_nodes():
     """
     Initializes the external custom nodes.
 
@@ -2204,7 +2207,7 @@ def init_external_custom_nodes():
                 logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
                 continue
             time_before = time.perf_counter()
-            success = load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
+            success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
             node_import_times.append((time.perf_counter() - time_before, module_path, success))
 
     if len(node_import_times) > 0:
@@ -2217,7 +2220,7 @@ def init_external_custom_nodes():
             logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
         logging.info("")
 
-def init_builtin_extra_nodes():
+async def init_builtin_extra_nodes():
     """
     Initializes the built-in extra nodes in ComfyUI.
 
@@ -2363,13 +2366,13 @@ def init_builtin_extra_nodes():
 
     import_failed = []
     for node_file in extras_files:
-        if not load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
+        if not await load_custom_node(os.path.join(extras_dir, node_file), module_parent="comfy_extras"):
             import_failed.append(node_file)
 
     return import_failed
 
 
-def init_builtin_api_nodes():
+async def init_builtin_api_nodes():
     api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes")
     api_nodes_files = [
         "nodes_ideogram.py",
@@ -2390,26 +2393,35 @@ def init_builtin_api_nodes():
         "nodes_gemini.py",
     ]
 
-    if not load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):
+    if not await load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):
         return api_nodes_files
 
     import_failed = []
     for node_file in api_nodes_files:
-        if not load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"):
+        if not await load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"):
             import_failed.append(node_file)
 
     return import_failed
 
+async def init_public_apis():
+    register_versions([
+        ComfyAPIWithVersion(
+            version=getattr(v, "VERSION"),
+            api_class=v
+        ) for v in supported_versions
+    ])
+
 
-def init_extra_nodes(init_custom_nodes=True, init_api_nodes=True):
-    import_failed = init_builtin_extra_nodes()
+async def init_extra_nodes(init_custom_nodes=True, init_api_nodes=True):
+    await init_public_apis()
 
+    import_failed = await init_builtin_extra_nodes()
 
     import_failed_api = []
     if init_api_nodes:
-        import_failed_api = init_builtin_api_nodes()
+        import_failed_api = await init_builtin_api_nodes()
 
     if init_custom_nodes:
-        init_external_custom_nodes()
+        await init_external_custom_nodes()
     else:
         logging.info("Skipping loading of custom nodes")
 
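Note: init_public_apis() registers every entry of comfy_api.version_list.supported_versions before any built-in, API, or external nodes are loaded, so version lookups resolve during node import. A hedged sketch of the registration shape; the ComfyAPIWithVersion layout and registry here are inferred from this call site, not taken from comfy_api internals:

from dataclasses import dataclass

@dataclass
class ComfyAPIWithVersion:
    version: str
    api_class: type

_registry: list[ComfyAPIWithVersion] = []

def register_versions(versions: list[ComfyAPIWithVersion]) -> None:
    _registry.extend(versions)

class ComfyAPILatest:  # hypothetical stand-in for an entry in supported_versions
    VERSION = "latest"

supported_versions = [ComfyAPILatest]

register_versions([
    ComfyAPIWithVersion(version=getattr(v, "VERSION"), api_class=v)
    for v in supported_versions
])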
Some files were not shown because too many files have changed in this diff.