diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 39d1992d7..69ce998eb 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -15,6 +15,14 @@ body: steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen. If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first. + - type: checkboxes + id: custom-nodes-test + attributes: + label: Custom Node Testing + description: Please confirm you have tried to reproduce the issue with all custom nodes disabled. + options: + - label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help) + required: true - type: textarea attributes: label: Expected Behavior diff --git a/.github/ISSUE_TEMPLATE/user-support.yml b/.github/ISSUE_TEMPLATE/user-support.yml index df28804c6..50657d493 100644 --- a/.github/ISSUE_TEMPLATE/user-support.yml +++ b/.github/ISSUE_TEMPLATE/user-support.yml @@ -11,6 +11,14 @@ body: **2:** You have made an effort to find public answers to your question before asking here. In other words, you googled it first, and scrolled through recent help topics. If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first. + - type: checkboxes + id: custom-nodes-test + attributes: + label: Custom Node Testing + description: Please confirm you have tried to reproduce the issue with all custom nodes disabled. + options: + - label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help) + required: true - type: textarea attributes: label: Your question diff --git a/CODEOWNERS b/CODEOWNERS index 013ea8622..c4acbf06e 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -5,20 +5,20 @@ # Inlined the team members for now. 
# Maintainers -*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne # Python web server -/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne -/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne -/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne +/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne +/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne +/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne # Node developers -/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne -/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne +/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne +/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne diff --git a/README.md b/README.md index deee70c6b..0de4a6bb5 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ [![Website][website-shield]][website-url] [![Dynamic JSON Badge][discord-shield]][discord-url] +[![Twitter][twitter-shield]][twitter-url] [![Matrix][matrix-shield]][matrix-url]
[![][github-release-shield]][github-release-link] @@ -20,6 +21,8 @@ [discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total [discord-url]: https://www.comfy.org/discord +[twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI +[twitter-url]: https://x.com/ComfyUI [github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver [github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases @@ -62,12 +65,13 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/) - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/) - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/) + - [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/) - Video Models - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/) - [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/) - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/) - - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) + - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/) - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/) - Audio Models - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/) @@ -95,7 +99,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/) - Latent previews with [TAESD](#how-to-show-high-quality-previews) - Starts up very fast. -- Works fully offline: will never download anything. +- Works fully offline: core will never download anything unless you want to. +- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview). - [Config file](extra_model_paths.yaml.example) to set the search paths for models. Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/) @@ -110,7 +115,6 @@ ComfyUI follows a weekly release cycle every Friday, with three interconnected r 2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)** - Builds a new release using the latest stable core version - - Version numbers match the core release (e.g., Desktop v1.7.0 uses Core v1.7.0) 3. 
**[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)** - Weekly frontend updates are merged into the core repository @@ -198,11 +202,11 @@ Put your VAE in: models/vae ### AMD GPUs (Linux only) AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: -```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2.4``` +```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.3``` -This is the command to install the nightly with ROCm 6.3 which might have some performance improvements: +This is the command to install the nightly with ROCm 6.4 which might have some performance improvements: -```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3``` +```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4``` ### Intel GPUs (Windows and Linux) @@ -302,7 +306,7 @@ For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 pyt ### AMD ROCm Tips -You can enable experimental memory efficient attention on pytorch 2.5 in ComfyUI on RDNA3 and potentially other AMD GPUs using this command: +You can enable experimental memory efficient attention on recent pytorch in ComfyUI on some AMD GPUs using this command; it should already be enabled by default on RDNA3. If this improves speed for you on the latest pytorch on your GPU, please report it so that I can enable it by default. ```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention``` diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 000000000..12f18712f --- /dev/null +++ b/alembic.ini @@ -0,0 +1,84 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +# Use forward slashes (/) also on windows to provide an os agnostic path +script_location = alembic_db + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. +# Any required deps can be installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to alembic_db/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below.
+# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +# version_path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. +version_path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = sqlite:///user/comfyui.db + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME diff --git a/alembic_db/README.md b/alembic_db/README.md new file mode 100644 index 000000000..3b808c7ca --- /dev/null +++ b/alembic_db/README.md @@ -0,0 +1,4 @@ +## Generate new revision + +1. Update models in `/app/database/models.py` +2. Run `alembic revision --autogenerate -m "{your message}"` diff --git a/alembic_db/env.py b/alembic_db/env.py new file mode 100644 index 000000000..4d7770679 --- /dev/null +++ b/alembic_db/env.py @@ -0,0 +1,64 @@ +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + + +from app.database.models import Base +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + Calls to context.execute() here emit the given string to the + script output. + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + In this scenario we need to create an Engine + and associate a connection with the context. 
+ """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic_db/script.py.mako b/alembic_db/script.py.mako new file mode 100644 index 000000000..480b130d6 --- /dev/null +++ b/alembic_db/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/app/database/db.py b/app/database/db.py new file mode 100644 index 000000000..1de8b80ed --- /dev/null +++ b/app/database/db.py @@ -0,0 +1,112 @@ +import logging +import os +import shutil +from app.logger import log_startup_warning +from utils.install_util import get_missing_requirements_message +from comfy.cli_args import args + +_DB_AVAILABLE = False +Session = None + + +try: + from alembic import command + from alembic.config import Config + from alembic.runtime.migration import MigrationContext + from alembic.script import ScriptDirectory + from sqlalchemy import create_engine + from sqlalchemy.orm import sessionmaker + + _DB_AVAILABLE = True +except ImportError as e: + log_startup_warning( + f""" +------------------------------------------------------------------------ +Error importing dependencies: {e} +{get_missing_requirements_message()} +This error is happening because ComfyUI now uses a local sqlite database. 
+------------------------------------------------------------------------ +""".strip() + ) + + +def dependencies_available(): + """ + Temporary function to check if the dependencies are available + """ + return _DB_AVAILABLE + + +def can_create_session(): + """ + Temporary function to check if the database is available to create a session + During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created + """ + return dependencies_available() and Session is not None + + +def get_alembic_config(): + root_path = os.path.join(os.path.dirname(__file__), "../..") + config_path = os.path.abspath(os.path.join(root_path, "alembic.ini")) + scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db")) + + config = Config(config_path) + config.set_main_option("script_location", scripts_path) + config.set_main_option("sqlalchemy.url", args.database_url) + + return config + + +def get_db_path(): + url = args.database_url + if url.startswith("sqlite:///"): + return url.split("///")[1] + else: + raise ValueError(f"Unsupported database URL '{url}'.") + + +def init_db(): + db_url = args.database_url + logging.debug(f"Database URL: {db_url}") + db_path = get_db_path() + db_exists = os.path.exists(db_path) + + config = get_alembic_config() + + # Check if we need to upgrade + engine = create_engine(db_url) + conn = engine.connect() + + context = MigrationContext.configure(conn) + current_rev = context.get_current_revision() + + script = ScriptDirectory.from_config(config) + target_rev = script.get_current_head() + + if target_rev is None: + logging.warning("No target revision found.") + elif current_rev != target_rev: + # Backup the database pre upgrade + backup_path = db_path + ".bkp" + if db_exists: + shutil.copy(db_path, backup_path) + else: + backup_path = None + + try: + command.upgrade(config, target_rev) + logging.info(f"Database upgraded from {current_rev} to {target_rev}") + except Exception as e: + if backup_path: + # Restore the database from backup if upgrade fails + shutil.copy(backup_path, db_path) + os.remove(backup_path) + logging.exception("Error upgrading database: ") + raise e + + global Session + Session = sessionmaker(bind=engine) + + +def create_session(): + return Session() diff --git a/app/database/models.py b/app/database/models.py new file mode 100644 index 000000000..6facfb8f2 --- /dev/null +++ b/app/database/models.py @@ -0,0 +1,14 @@ +from sqlalchemy.orm import declarative_base + +Base = declarative_base() + + +def to_dict(obj): + fields = obj.__table__.columns.keys() + return { + field: (val.to_dict() if hasattr(val, "to_dict") else val) + for field in fields + if (val := getattr(obj, field)) + } + +# TODO: Define models here diff --git a/app/frontend_management.py b/app/frontend_management.py index 7b7923b79..001ebbecb 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -16,26 +16,17 @@ from importlib.metadata import version import requests from typing_extensions import NotRequired +from utils.install_util import get_missing_requirements_message, requirements_path + from comfy.cli_args import DEFAULT_VERSION_STRING import app.logger -# The path to the requirements.txt file -req_path = Path(__file__).parents[1] / "requirements.txt" - def frontend_install_warning_message(): - """The warning message to display when the frontend version is not up to date.""" - - extra = "" - if sys.flags.no_user_site: - extra = "-s " return f""" -Please install the updated requirements.txt file by 
running: -{sys.executable} {extra}-m pip install -r {req_path} +{get_missing_requirements_message()} This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead. - -If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem """.strip() @@ -48,7 +39,7 @@ def check_frontend_version(): try: frontend_version_str = version("comfyui-frontend-package") frontend_version = parse_version(frontend_version_str) - with open(req_path, "r", encoding="utf-8") as f: + with open(requirements_path, "r", encoding="utf-8") as f: required_frontend = parse_version(f.readline().split("=")[-1]) if frontend_version < required_frontend: app.logger.log_startup_warning( @@ -121,9 +112,22 @@ class FrontEndProvider: response.raise_for_status() # Raises an HTTPError if the response was an error return response.json() + @cached_property + def latest_prerelease(self) -> Release: + """Get the latest pre-release version - even if it's older than the latest release""" + release = [release for release in self.all_releases if release["prerelease"]] + + if not release: + raise ValueError("No pre-releases found") + + # GitHub returns releases in reverse chronological order, so first is latest + return release[0] + def get_release(self, version: str) -> Release: if version == "latest": return self.latest_release + elif version == "prerelease": + return self.latest_prerelease else: for release in self.all_releases: if release["tag_name"] in [version, f"v{version}"]: @@ -205,6 +209,19 @@ comfyui-workflow-templates is not installed. """.strip() ) + @classmethod + def embedded_docs_path(cls) -> str: + """Get the path to embedded documentation""" + try: + import comfyui_embedded_docs + + return str( + importlib.resources.files(comfyui_embedded_docs) / "docs" + ) + except ImportError: + logging.info("comfyui-embedded-docs package not found") + return None + @classmethod def parse_version_string(cls, value: str) -> tuple[str, str, str]: """ @@ -217,7 +234,7 @@ comfyui-workflow-templates is not installed. Raises: argparse.ArgumentTypeError: If the version string is invalid. """ - VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$" + VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+[-._a-zA-Z0-9]*|latest|prerelease)$" match_result = re.match(VERSION_PATTERN, value) if match_result is None: raise argparse.ArgumentTypeError(f"Invalid version string: {value}") diff --git a/comfy/cli_args.py b/comfy/cli_args.py index d5ecdf89c..1609ba308 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -88,6 +88,7 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE" parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.") parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.") +parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.") class LatentPreviewMethod(enum.Enum): NoPreviews = "none" @@ -202,6 +203,11 @@ parser.add_argument( help="Set the base URL for the ComfyUI API. 
(default: https://api.comfy.org)", ) +database_default_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") +) +parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") + if comfy.options.args_parsing: args = parser.parse_args() else: diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 2ffc9c021..071b98332 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -37,6 +37,8 @@ class IO(StrEnum): CONTROL_NET = "CONTROL_NET" VAE = "VAE" MODEL = "MODEL" + LORA_MODEL = "LORA_MODEL" + LOSS_MAP = "LOSS_MAP" CLIP_VISION = "CLIP_VISION" CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT" STYLE_MODEL = "STYLE_MODEL" @@ -235,7 +237,7 @@ class ComfyNodeABC(ABC): DEPRECATED: bool """Flags a node as deprecated, indicating to users that they should find alternatives to this node.""" API_NODE: Optional[bool] - """Flags a node as an API node.""" + """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview.""" @classmethod @abstractmethod diff --git a/comfy/conds.py b/comfy/conds.py index 211fb8d57..2af2a43a3 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -24,6 +24,10 @@ class CONDRegular: conds.append(x.cond) return torch.cat(conds) + def size(self): + return list(self.cond.size()) + + class CONDNoiseShape(CONDRegular): def process_cond(self, batch_size, device, area, **kwargs): data = self.cond @@ -64,6 +68,7 @@ class CONDCrossAttn(CONDRegular): out.append(c) return torch.cat(out) + class CONDConstant(CONDRegular): def __init__(self, cond): self.cond = cond @@ -78,3 +83,48 @@ class CONDConstant(CONDRegular): def concat(self, others): return self.cond + + def size(self): + return [1] + + +class CONDList(CONDRegular): + def __init__(self, cond): + self.cond = cond + + def process_cond(self, batch_size, device, **kwargs): + out = [] + for c in self.cond: + out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device)) + + return self._copy_with(out) + + def can_concat(self, other): + if len(self.cond) != len(other.cond): + return False + for i in range(len(self.cond)): + if self.cond[i].shape != other.cond[i].shape: + return False + + return True + + def concat(self, others): + out = [] + for i in range(len(self.cond)): + o = [self.cond[i]] + for x in others: + o.append(x.cond[i]) + out.append(torch.cat(o)) + + return out + + def size(self): # hackish implementation to make the mem estimation work + o = 0 + c = 1 + for c in self.cond: + size = c.size() + o += math.prod(size) + if len(size) > 1: + c = size[1] + + return [1, c, o // c] diff --git a/comfy/controlnet.py b/comfy/controlnet.py index d1ffeea36..c782c827f 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -433,8 +433,9 @@ class ControlLora(ControlNet): pass for k in self.control_weights: - if k not in {"lora_controlnet"}: - comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device())) + if (k not in {"lora_controlnet"}): + if (k.endswith(".up") or k.endswith(".down") or k.endswith(".weight") or k.endswith(".bias")) and ("__" not in k): + comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device())) def copy(self): c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling) diff --git 
a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index fbdf6f554..8030048fc 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1,4 +1,5 @@ import math +from functools import partial from scipy import integrate import torch @@ -142,6 +143,33 @@ class BrownianTreeNoiseSampler: return self.tree(t0, t1) / (t1 - t0).abs().sqrt() +def sigma_to_half_log_snr(sigma, model_sampling): + """Convert sigma to half-logSNR log(alpha_t / sigma_t).""" + if isinstance(model_sampling, comfy.model_sampling.CONST): + # log((1 - t) / t) = log((1 - sigma) / sigma) + return sigma.logit().neg() + return sigma.log().neg() + + +def half_log_snr_to_sigma(half_log_snr, model_sampling): + """Convert half-logSNR log(alpha_t / sigma_t) to sigma.""" + if isinstance(model_sampling, comfy.model_sampling.CONST): + # 1 / (1 + exp(half_log_snr)) + return half_log_snr.neg().sigmoid() + return half_log_snr.neg().exp() + + +def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4): + """Adjust the first sigma to avoid invalid logSNR.""" + if len(sigmas) <= 1: + return sigmas + if isinstance(model_sampling, comfy.model_sampling.CONST): + if sigmas[0] >= 1: + sigmas = sigmas.clone() + sigmas[0] = model_sampling.percent_to_sigma(percent_offset) + return sigmas + + @torch.no_grad() def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" @@ -753,6 +781,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No old_denoised = denoised return x + @torch.no_grad() def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): """DPM-Solver++(2M) SDE.""" @@ -768,9 +797,12 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + old_denoised = None - h_last = None - h = None + h, h_last = None, None for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -781,26 +813,29 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl x = denoised else: # DPM-Solver++(2M) SDE - t, s = -sigmas[i].log(), -sigmas[i + 1].log() - h = s - t - eta_h = eta * h + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s + h_eta = h * (eta + 1) - x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised + alpha_t = sigmas[i + 1] * lambda_t.exp() + + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised if old_denoised is not None: r = h_last / h if solver_type == 'heun': - x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised) + x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised) elif solver_type == 'midpoint': - x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised) + x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * 
(denoised - old_denoised) - if eta: - x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise + if eta > 0 and s_noise > 0: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise old_denoised = denoised h_last = h return x + @torch.no_grad() def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """DPM-Solver++(3M) SDE.""" @@ -814,6 +849,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + denoised_1, denoised_2 = None, None h, h_1, h_2 = None, None, None @@ -825,13 +864,16 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl # Denoising step x = denoised else: - t, s = -sigmas[i].log(), -sigmas[i + 1].log() - h = s - t + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s h_eta = h * (eta + 1) - x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised + alpha_t = sigmas[i + 1] * lambda_t.exp() + + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised if h_2 is not None: + # DPM-Solver++(3M) SDE r0 = h_1 / h r1 = h_2 / h d1_0 = (denoised - denoised_1) / r0 @@ -840,20 +882,22 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl d2 = (d1_0 - d1_1) / (r0 + r1) phi_2 = h_eta.neg().expm1() / h_eta + 1 phi_3 = phi_2 / h_eta - 0.5 - x = x + phi_2 * d1 - phi_3 * d2 + x = x + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) * d2 elif h_1 is not None: + # DPM-Solver++(2M) SDE r = h_1 / h d = (denoised - denoised_1) / r phi_2 = h_eta.neg().expm1() / h_eta + 1 - x = x + phi_2 * d + x = x + (alpha_t * phi_2) * d - if eta: + if eta > 0 and s_noise > 0: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise denoised_1, denoised_2 = denoised, denoised_1 h_1, h_2 = h, h_1 return x + @torch.no_grad() def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): if len(sigmas) <= 1: @@ -863,6 +907,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) + @torch.no_grad() def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): if len(sigmas) <= 1: @@ -872,6 +917,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, 
eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) + @torch.no_grad() def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): if len(sigmas) <= 1: @@ -1449,12 +1495,12 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None old_denoised = denoised return x + @torch.no_grad() def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5): - ''' - SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2 - Arxiv: https://arxiv.org/abs/2305.14267 - ''' + """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2. + arXiv: https://arxiv.org/abs/2305.14267 + """ extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler @@ -1462,6 +1508,11 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non inject_noise = eta > 0 and s_noise > 0 + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: @@ -1469,80 +1520,96 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non if sigmas[i + 1] == 0: x = denoised else: - t, t_next = -sigmas[i].log(), -sigmas[i + 1].log() - h = t_next - t + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s h_eta = h * (eta + 1) - s = t + r * h + lambda_s_1 = lambda_s + r * h fac = 1 / (2 * r) - sigma_s = s.neg().exp() + sigma_s_1 = sigma_fn(lambda_s_1) + + # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t) + alpha_s_1 = sigma_s_1 * lambda_s_1.exp() + alpha_t = sigmas[i + 1] * lambda_t.exp() coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1() if inject_noise: + # 0 < r < 1 noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt() - noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt() - noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1]) + noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt() + noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1]) # Step 1 - x_2 = (coeff_1 + 1) * x - coeff_1 * denoised - if inject_noise: - x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise - denoised_2 = model(x_2, sigma_s * s_in, **extra_args) - - # Step 2 - denoised_d = (1 - fac) * denoised + fac * denoised_2 - x = (coeff_2 + 1) * x - coeff_2 * denoised_d - if inject_noise: - x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise - return x - -@torch.no_grad() -def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3): - ''' - SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3 - Arxiv: https://arxiv.org/abs/2305.14267 - ''' - extra_args = {} if extra_args is None else extra_args - seed = 
extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler - s_in = x.new_ones([x.shape[0]]) - - inject_noise = eta > 0 and s_noise > 0 - - for i in trange(len(sigmas) - 1, disable=disable): - denoised = model(x, sigmas[i] * s_in, **extra_args) - if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) - if sigmas[i + 1] == 0: - x = denoised - else: - t, t_next = -sigmas[i].log(), -sigmas[i + 1].log() - h = t_next - t - h_eta = h * (eta + 1) - s_1 = t + r_1 * h - s_2 = t + r_2 * h - sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp() - - coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1() - if inject_noise: - noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt() - noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt() - noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt() - noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1]) - - # Step 1 - x_2 = (coeff_1 + 1) * x - coeff_1 * denoised + x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised if inject_noise: x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) # Step 2 - x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised) + denoised_d = (1 - fac) * denoised + fac * denoised_2 + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d + if inject_noise: + x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise + return x + + +@torch.no_grad() +def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3): + """SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3. 
+ arXiv: https://arxiv.org/abs/2305.14267 + """ + extra_args = {} if extra_args is None else extra_args + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + + inject_noise = eta > 0 and s_noise > 0 + + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + x = denoised + else: + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s + h_eta = h * (eta + 1) + lambda_s_1 = lambda_s + r_1 * h + lambda_s_2 = lambda_s + r_2 * h + sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2) + + # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t) + alpha_s_1 = sigma_s_1 * lambda_s_1.exp() + alpha_s_2 = sigma_s_2 * lambda_s_2.exp() + alpha_t = sigmas[i + 1] * lambda_t.exp() + + coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1() + if inject_noise: + # 0 < r_1 < r_2 < 1 + noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt() + noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt() + noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt() + noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1]) + + # Step 1 + x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised + if inject_noise: + x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise + denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) + + # Step 2 + x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised) if inject_noise: x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args) # Step 3 - x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised) + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. 
/ r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised) if inject_noise: x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise return x diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py index 35da91ee2..2a0dec606 100644 --- a/comfy/ldm/chroma/layers.py +++ b/comfy/ldm/chroma/layers.py @@ -80,15 +80,13 @@ class DoubleStreamBlock(nn.Module): (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec # prepare image for attention - img_modulated = self.img_norm1(img) - img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img)) img_qkv = self.img_attn.qkv(img_modulated) img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) # prepare txt for attention - txt_modulated = self.txt_norm1(txt) - txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt)) txt_qkv = self.txt_attn.qkv(txt_modulated) txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) @@ -102,12 +100,12 @@ class DoubleStreamBlock(nn.Module): txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] # calculate the img bloks - img = img + img_mod1.gate * self.img_attn.proj(img_attn) - img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn)) + img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img)))) # calculate the txt bloks - txt += txt_mod1.gate * self.txt_attn.proj(txt_attn) - txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn)) + txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt)))) if txt.dtype == torch.float16: txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504) @@ -152,7 +150,7 @@ class SingleStreamBlock(nn.Module): def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor: mod = vec - x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x)) qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) @@ -162,7 +160,7 @@ class SingleStreamBlock(nn.Module): attn = attention(q, k, v, pe=pe, mask=attn_mask) # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) - x += mod.gate * output + x.addcmul_(mod.gate, output) if x.dtype == torch.float16: x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) return x @@ -178,6 +176,6 @@ class LastLayer(nn.Module): shift, scale = vec shift = shift.squeeze(1) scale = scale.squeeze(1) - x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = torch.addcmul(shift[:, None, :], 1 + scale[:, None, :], self.norm_final(x)) x = self.linear(x) return x diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index 
636748fc5..c75023a31 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -163,7 +163,7 @@ class Chroma(nn.Module): distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype) # get all modulation index - modulation_index = timestep_embedding(torch.arange(mod_index_length), 32).to(img.device, img.dtype) + modulation_index = timestep_embedding(torch.arange(mod_index_length, device=img.device), 32).to(img.device, img.dtype) # we need to broadcast the modulation index here so each batch has all of the index modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype) # and we need to broadcast timestep and guidance along too diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py index a12f892d2..5c4356a3f 100644 --- a/comfy/ldm/cosmos/blocks.py +++ b/comfy/ldm/cosmos/blocks.py @@ -26,16 +26,6 @@ from torch import nn from comfy.ldm.modules.attention import optimized_attention -def apply_rotary_pos_emb( - t: torch.Tensor, - freqs: torch.Tensor, -) -> torch.Tensor: - t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float() - t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1] - t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t) - return t_out - - def get_normalization(name: str, channels: int, weight_args={}, operations=None): if name == "I": return nn.Identity() diff --git a/comfy/ldm/cosmos/position_embedding.py b/comfy/ldm/cosmos/position_embedding.py index 4d6a58dba..c925811d4 100644 --- a/comfy/ldm/cosmos/position_embedding.py +++ b/comfy/ldm/cosmos/position_embedding.py @@ -66,15 +66,16 @@ class VideoRopePosition3DEmb(VideoPositionEmb): h_extrapolation_ratio: float = 1.0, w_extrapolation_ratio: float = 1.0, t_extrapolation_ratio: float = 1.0, + enable_fps_modulation: bool = True, device=None, **kwargs, # used for compatibility with other positional embeddings; unused in this class ): del kwargs super().__init__() - self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device)) self.base_fps = base_fps self.max_h = len_h self.max_w = len_w + self.enable_fps_modulation = enable_fps_modulation dim = head_dim dim_h = dim // 6 * 2 @@ -132,21 +133,19 @@ class VideoRopePosition3DEmb(VideoPositionEmb): temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device)) B, T, H, W, _ = B_T_H_W_C + seq = torch.arange(max(H, W, T), dtype=torch.float, device=device) uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max()) assert ( uniform_fps or B == 1 or T == 1 ), "For video batch, batch size should be 1 for non-uniform fps. 
For image batch, T should be 1" - assert ( - H <= self.max_h and W <= self.max_w - ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})" - half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs) - half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs) + half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs) + half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs) # apply sequence scaling in temporal dimension - if fps is None: # image case - half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs) + if fps is None or self.enable_fps_modulation is False: # image case + half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs) else: - half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs) + half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs) half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1) half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1) diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py new file mode 100644 index 000000000..316117f77 --- /dev/null +++ b/comfy/ldm/cosmos/predict2.py @@ -0,0 +1,864 @@ +# original code from: https://github.com/nvidia-cosmos/cosmos-predict2 + +import torch +from torch import nn +from einops import rearrange +from einops.layers.torch import Rearrange +import logging +from typing import Callable, Optional, Tuple +import math + +from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis +from torchvision import transforms + +from comfy.ldm.modules.attention import optimized_attention + +def apply_rotary_pos_emb( + t: torch.Tensor, + freqs: torch.Tensor, +) -> torch.Tensor: + t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float() + t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1] + t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t) + return t_out + + +# ---------------------- Feed Forward Network ----------------------- +class GPT2FeedForward(nn.Module): + def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None: + super().__init__() + self.activation = nn.GELU() + self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype) + self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype) + + self._layer_id = None + self._dim = d_model + self._hidden_dim = d_ff + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.layer1(x) + + x = self.activation(x) + x = self.layer2(x) + return x + + +def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor: + """Computes multi-head attention using PyTorch's native implementation. + + This function provides a PyTorch backend alternative to Transformer Engine's attention operation. + It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product + attention, and rearranges the output back to the original format. 
+ + The input tensor names use the following dimension conventions: + + - B: batch size + - S: sequence length + - H: number of attention heads + - D: head dimension + + Args: + q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim) + k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim) + v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim) + + Returns: + Attention output tensor with shape (batch, seq_len, n_heads * head_dim) + """ + in_q_shape = q_B_S_H_D.shape + in_k_shape = k_B_S_H_D.shape + q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1]) + k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) + v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) + return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True) + + +class Attention(nn.Module): + """ + A flexible attention module supporting both self-attention and cross-attention mechanisms. + + This module implements a multi-head attention layer that can operate in either self-attention + or cross-attention mode. The mode is determined by whether a context dimension is provided. + The implementation uses scaled dot-product attention and supports optional bias terms and + dropout regularization. + + Args: + query_dim (int): The dimensionality of the query vectors. + context_dim (int, optional): The dimensionality of the context (key/value) vectors. + If None, the module operates in self-attention mode using query_dim. Default: None + n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8 + head_dim (int, optional): The dimension of each attention head. Default: 64 + dropout (float, optional): Dropout probability applied to the output. Default: 0.0 + qkv_format (str, optional): Format specification for QKV tensors. Default: "bshd" + backend (str, optional): Backend to use for the attention operation. Default: "transformer_engine" + + Examples: + >>> # Self-attention with 512 dimensions and 8 heads + >>> self_attn = Attention(query_dim=512) + >>> x = torch.randn(32, 16, 512) # (batch_size, seq_len, dim) + >>> out = self_attn(x) # (32, 16, 512) + + >>> # Cross-attention + >>> cross_attn = Attention(query_dim=512, context_dim=256) + >>> query = torch.randn(32, 16, 512) + >>> context = torch.randn(32, 8, 256) + >>> out = cross_attn(query, context) # (32, 16, 512) + """ + + def __init__( + self, + query_dim: int, + context_dim: Optional[int] = None, + n_heads: int = 8, + head_dim: int = 64, + dropout: float = 0.0, + device=None, + dtype=None, + operations=None, + ) -> None: + super().__init__() + logging.debug( + f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{n_heads} heads with a dimension of {head_dim}." 
+ ) + self.is_selfattn = context_dim is None # self attention + + context_dim = query_dim if context_dim is None else context_dim + inner_dim = head_dim * n_heads + + self.n_heads = n_heads + self.head_dim = head_dim + self.query_dim = query_dim + self.context_dim = context_dim + + self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.v_norm = nn.Identity() + + self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype) + self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity() + + self.attn_op = torch_attention_op + + self._query_dim = query_dim + self._context_dim = context_dim + self._inner_dim = inner_dim + + def compute_qkv( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + rope_emb: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q = self.q_proj(x) + context = x if context is None else context + k = self.k_proj(context) + v = self.v_proj(context) + q, k, v = map( + lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim), + (q, k, v), + ) + + def apply_norm_and_rotary_pos_emb( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q = self.q_norm(q) + k = self.k_norm(k) + v = self.v_norm(v) + if self.is_selfattn and rope_emb is not None: # only apply to self-attention! 
+ q = apply_rotary_pos_emb(q, rope_emb)
+ k = apply_rotary_pos_emb(k, rope_emb)
+ return q, k, v
+
+ q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)
+
+ return q, k, v
+
+ def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+ result = self.attn_op(q, k, v) # [B, S, H * D]
+ return self.output_dropout(self.output_proj(result))
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ context: Optional[torch.Tensor] = None,
+ rope_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ x (Tensor): The query tensor of shape [B, Mq, K]
+ context (Optional[Tensor]): The key/value tensor of shape [B, Mk, K], or None to use x as context (self-attention)
+ rope_emb (Optional[Tensor]): Rotary positional embedding; applied only in self-attention mode
+ """
+ q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
+ return self.compute_attention(q, k, v)
+
+
+class Timesteps(nn.Module):
+ def __init__(self, num_channels: int):
+ super().__init__()
+ self.num_channels = num_channels
+
+ def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor:
+ assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}"
+ timesteps = timesteps_B_T.flatten().float()
+ half_dim = self.num_channels // 2
+ exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
+ exponent = exponent / half_dim
+
+ emb = torch.exp(exponent)
+ emb = timesteps[:, None].float() * emb[None, :]
+
+ sin_emb = torch.sin(emb)
+ cos_emb = torch.cos(emb)
+ emb = torch.cat([cos_emb, sin_emb], dim=-1)
+
+ return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1])
+
+
+class TimestepEmbedding(nn.Module):
+ def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):
+ super().__init__()
+ logging.debug(
+ f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
+ )
+ self.in_dim = in_features
+ self.out_dim = out_features
+ self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype)
+ self.activation = nn.SiLU()
+ self.use_adaln_lora = use_adaln_lora
+ if use_adaln_lora:
+ self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype)
+ else:
+ self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype)
+
+ def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ emb = self.linear_1(sample)
+ emb = self.activation(emb)
+ emb = self.linear_2(emb)
+
+ if self.use_adaln_lora:
+ adaln_lora_B_T_3D = emb
+ emb_B_T_D = sample
+ else:
+ adaln_lora_B_T_3D = None
+ emb_B_T_D = emb
+
+ return emb_B_T_D, adaln_lora_B_T_3D
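+
+
+# A minimal sketch of the timestep embedding path above (shapes are illustrative;
+# comfy.ops.disable_weight_init is assumed as the `operations` provider; with
+# use_adaln_lora=True the first output is the raw sinusoidal embedding):
+#
+# >>> t_enc = Timesteps(256)
+# >>> t_emb = TimestepEmbedding(256, 256, use_adaln_lora=True, operations=comfy.ops.disable_weight_init)
+# >>> sincos = t_enc(torch.tensor([[999.0]])) # (B=1, T=1, 256)
+# >>> emb, adaln_lora = t_emb(sincos) # (1, 1, 256) and (1, 1, 3 * 256)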
+
+
+class PatchEmbed(nn.Module):
+ """
+ PatchEmbed is a module for embedding patches from an input tensor by rearranging
+ non-overlapping spatio-temporal patches and projecting each one with a linear layer.
+ This module can process inputs with temporal (video) and spatial (image) dimensions,
+ making it suitable for video and image processing tasks. It divides the input into
+ patches and embeds each patch into a vector of size `out_channels`.
+
+ Parameters:
+ - spatial_patch_size (int): The size of each spatial patch.
+ - temporal_patch_size (int): The size of each temporal patch.
+ - in_channels (int): Number of input channels. Default: 3.
+ - out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
+ """
+
+ def __init__(
+ self,
+ spatial_patch_size: int,
+ temporal_patch_size: int,
+ in_channels: int = 3,
+ out_channels: int = 768,
+ device=None, dtype=None, operations=None
+ ):
+ super().__init__()
+ self.spatial_patch_size = spatial_patch_size
+ self.temporal_patch_size = temporal_patch_size
+
+ self.proj = nn.Sequential(
+ Rearrange(
+ "b c (t r) (h m) (w n) -> b t h w (c r m n)",
+ r=temporal_patch_size,
+ m=spatial_patch_size,
+ n=spatial_patch_size,
+ ),
+ operations.Linear(
+ in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype
+ ),
+ )
+ self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the PatchEmbed module.
+
+ Parameters:
+ - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
+ B is the batch size,
+ C is the number of channels,
+ T is the temporal dimension,
+ H is the height, and
+ W is the width of the input.
+
+ Returns:
+ - torch.Tensor: The embedded patches as a tensor of shape
+ (B, T / temporal_patch_size, H / spatial_patch_size, W / spatial_patch_size, out_channels).
+ """
+ assert x.dim() == 5
+ _, _, T, H, W = x.shape
+ assert (
+ H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
+ ), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}"
+ assert T % self.temporal_patch_size == 0
+ x = self.proj(x)
+ return x
+
+
+class FinalLayer(nn.Module):
+ """
+ The final layer of the video DiT.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ spatial_patch_size: int,
+ temporal_patch_size: int,
+ out_channels: int,
+ use_adaln_lora: bool = False,
+ adaln_lora_dim: int = 256,
+ device=None, dtype=None, operations=None
+ ):
+ super().__init__()
+ self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = operations.Linear(
+ hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
+ )
+ self.hidden_size = hidden_size
+ self.n_adaln_chunks = 2
+ self.use_adaln_lora = use_adaln_lora
+ self.adaln_lora_dim = adaln_lora_dim
+ if use_adaln_lora:
+ self.adaln_modulation = nn.Sequential(
+ nn.SiLU(),
+ operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype),
+ operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype),
+ )
+ else:
+ self.adaln_modulation = nn.Sequential(
+ nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype)
+ )
+
+ def forward(
+ self,
+ x_B_T_H_W_D: torch.Tensor,
+ emb_B_T_D: torch.Tensor,
+ adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
+ ):
+ if self.use_adaln_lora:
+ assert adaln_lora_B_T_3D is not None
+ shift_B_T_D, scale_B_T_D = (
+ self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
+ ).chunk(2, dim=-1)
+ else:
+ shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
+
+ shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d"), rearrange(
+ scale_B_T_D, "b t d -> b t 1 1 d"
+ )
+
+ def _fn(
+ _x_B_T_H_W_D: torch.Tensor,
+ _norm_layer: nn.Module,
+ _scale_B_T_1_1_D: torch.Tensor,
+ _shift_B_T_1_1_D: torch.Tensor,
+ ) -> torch.Tensor:
+ return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
+
+ x_B_T_H_W_D = _fn(x_B_T_H_W_D,
self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D) + x_B_T_H_W_O = self.linear(x_B_T_H_W_D) + return x_B_T_H_W_O + + +class Block(nn.Module): + """ + A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation. + Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation. + + Parameters: + x_dim (int): Dimension of input features + context_dim (int): Dimension of context features for cross-attention + num_heads (int): Number of attention heads + mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0 + use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False + adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256 + + The block applies the following sequence: + 1. Self-attention with AdaLN modulation + 2. Cross-attention with AdaLN modulation + 3. MLP with AdaLN modulation + + Each component uses skip connections and layer normalization. + """ + + def __init__( + self, + x_dim: int, + context_dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + self.x_dim = x_dim + self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations) + + self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.cross_attn = Attention( + x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations + ) + + self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations) + + self.use_adaln_lora = use_adaln_lora + if self.use_adaln_lora: + self.adaln_modulation_self_attn = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), + ) + self.adaln_modulation_cross_attn = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), + ) + self.adaln_modulation_mlp = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), + ) + else: + self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) + self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) + self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) + + def forward( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_T_D: torch.Tensor, + crossattn_emb: torch.Tensor, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_T_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if extra_per_block_pos_emb is not None: + x_B_T_H_W_D = 
x_B_T_H_W_D + extra_per_block_pos_emb + + if self.use_adaln_lora: + shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = ( + self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = ( + self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = ( + self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + else: + shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn( + emb_B_T_D + ).chunk(3, dim=-1) + shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn( + emb_B_T_D + ).chunk(3, dim=-1) + shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1) + + # Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting + shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d") + scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d") + gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d") + + shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d") + scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d") + gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d") + + shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d") + scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d") + gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d") + + B, T, H, W, D = x_B_T_H_W_D.shape + + def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D): + return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D + + normalized_x_B_T_H_W_D = _fn( + x_B_T_H_W_D, + self.layer_norm_self_attn, + scale_self_attn_B_T_1_1_D, + shift_self_attn_B_T_1_1_D, + ) + result_B_T_H_W_D = rearrange( + self.self_attn( + # normalized_x_B_T_HW_D, + rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), + None, + rope_emb=rope_emb_L_1_1_D, + ), + "b (t h w) d -> b t h w d", + t=T, + h=H, + w=W, + ) + x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D + + def _x_fn( + _x_B_T_H_W_D: torch.Tensor, + layer_norm_cross_attn: Callable, + _scale_cross_attn_B_T_1_1_D: torch.Tensor, + _shift_cross_attn_B_T_1_1_D: torch.Tensor, + ) -> torch.Tensor: + _normalized_x_B_T_H_W_D = _fn( + _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D + ) + _result_B_T_H_W_D = rearrange( + self.cross_attn( + rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), + crossattn_emb, + rope_emb=rope_emb_L_1_1_D, + ), + "b (t h w) d -> b t h w d", + t=T, + h=H, + w=W, + ) + return _result_B_T_H_W_D + + result_B_T_H_W_D = _x_fn( + x_B_T_H_W_D, + self.layer_norm_cross_attn, + scale_cross_attn_B_T_1_1_D, + shift_cross_attn_B_T_1_1_D, + ) + x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D + + normalized_x_B_T_H_W_D = _fn( + x_B_T_H_W_D, + self.layer_norm_mlp, + scale_mlp_B_T_1_1_D, + shift_mlp_B_T_1_1_D, + ) + result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D) + x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D + return x_B_T_H_W_D + + +class MiniTrainDIT(nn.Module): + """ + A clean impl of DIT that can load and reproduce the training results of the original DIT 
model in Cosmos 1.
+ A general implementation of an AdaLN-modulated, ViT-like (DiT) transformer for video processing.
+
+ Args:
+ max_img_h (int): Maximum height of the input images.
+ max_img_w (int): Maximum width of the input images.
+ max_frames (int): Maximum number of frames in the video sequence.
+ in_channels (int): Number of input channels (e.g., RGB channels for color images).
+ out_channels (int): Number of output channels.
+ patch_spatial (int): Spatial resolution of patches for input processing.
+ patch_temporal (int): Temporal resolution of patches for input processing.
+ concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
+ model_channels (int): Base number of channels used throughout the model.
+ num_blocks (int): Number of transformer blocks.
+ num_heads (int): Number of heads in the multi-head attention layers.
+ mlp_ratio (float): Expansion ratio for MLP blocks.
+ crossattn_emb_channels (int): Number of embedding channels for cross-attention.
+ pos_emb_cls (str): Type of positional embeddings.
+ pos_emb_learnable (bool): Whether positional embeddings are learnable.
+ pos_emb_interpolation (str): Method for interpolating positional embeddings.
+ min_fps (int): Minimum frames per second.
+ max_fps (int): Maximum frames per second.
+ use_adaln_lora (bool): Whether to use AdaLN-LoRA.
+ adaln_lora_dim (int): Dimension for AdaLN-LoRA.
+ rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
+ rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
+ rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
+ extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
+ extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
+ extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
+ extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
+ rope_enable_fps_modulation (bool): Whether RoPE frequencies are modulated by the input fps.
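+
+ Examples:
+ >>> # Illustrative sketch; these values mirror the 2B-style text-to-video
+ >>> # configuration and are not defaults. comfy.ops.disable_weight_init is
+ >>> # one valid choice for `operations`.
+ >>> model = MiniTrainDIT(
+ ... max_img_h=240, max_img_w=240, max_frames=128,
+ ... in_channels=16, out_channels=16, patch_spatial=2, patch_temporal=1,
+ ... model_channels=2048, num_blocks=28, num_heads=16,
+ ... crossattn_emb_channels=1024, pos_emb_cls="rope3d",
+ ... rope_h_extrapolation_ratio=4.0, rope_w_extrapolation_ratio=4.0,
+ ... use_adaln_lora=True, adaln_lora_dim=256, rope_enable_fps_modulation=False,
+ ... operations=comfy.ops.disable_weight_init,
+ ... )
+ >>> x = torch.randn(1, 16, 16, 44, 80) # (B, C, T, H, W) latent video
+ >>> out = model(x, torch.tensor([999.0]), torch.randn(1, 512, 1024))
+ >>> out.shape # (B, out_channels, T, H, W)
+ torch.Size([1, 16, 16, 44, 80])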
+ """ + + def __init__( + self, + max_img_h: int, + max_img_w: int, + max_frames: int, + in_channels: int, + out_channels: int, + patch_spatial: int, # tuple, + patch_temporal: int, + concat_padding_mask: bool = True, + # attention settings + model_channels: int = 768, + num_blocks: int = 10, + num_heads: int = 16, + mlp_ratio: float = 4.0, + # cross attention settings + crossattn_emb_channels: int = 1024, + # positional embedding settings + pos_emb_cls: str = "sincos", + pos_emb_learnable: bool = False, + pos_emb_interpolation: str = "crop", + min_fps: int = 1, + max_fps: int = 30, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + rope_h_extrapolation_ratio: float = 1.0, + rope_w_extrapolation_ratio: float = 1.0, + rope_t_extrapolation_ratio: float = 1.0, + extra_per_block_abs_pos_emb: bool = False, + extra_h_extrapolation_ratio: float = 1.0, + extra_w_extrapolation_ratio: float = 1.0, + extra_t_extrapolation_ratio: float = 1.0, + rope_enable_fps_modulation: bool = True, + image_model=None, + device=None, + dtype=None, + operations=None, + ) -> None: + super().__init__() + self.dtype = dtype + self.max_img_h = max_img_h + self.max_img_w = max_img_w + self.max_frames = max_frames + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + self.num_heads = num_heads + self.num_blocks = num_blocks + self.model_channels = model_channels + self.concat_padding_mask = concat_padding_mask + # positional embedding settings + self.pos_emb_cls = pos_emb_cls + self.pos_emb_learnable = pos_emb_learnable + self.pos_emb_interpolation = pos_emb_interpolation + self.min_fps = min_fps + self.max_fps = max_fps + self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio + self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio + self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio + self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb + self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio + self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio + self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio + self.rope_enable_fps_modulation = rope_enable_fps_modulation + + self.build_pos_embed(device=device, dtype=dtype) + self.use_adaln_lora = use_adaln_lora + self.adaln_lora_dim = adaln_lora_dim + self.t_embedder = nn.Sequential( + Timesteps(model_channels), + TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,), + ) + + in_channels = in_channels + 1 if concat_padding_mask else in_channels + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + device=device, dtype=dtype, operations=operations, + ) + + self.blocks = nn.ModuleList( + [ + Block( + x_dim=model_channels, + context_dim=crossattn_emb_channels, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + device=device, dtype=dtype, operations=operations, + ) + for _ in range(num_blocks) + ] + ) + + self.final_layer = FinalLayer( + hidden_size=self.model_channels, + spatial_patch_size=self.patch_spatial, + temporal_patch_size=self.patch_temporal, + out_channels=self.out_channels, + use_adaln_lora=self.use_adaln_lora, + adaln_lora_dim=self.adaln_lora_dim, + device=device, dtype=dtype, operations=operations, + ) + + self.t_embedding_norm = 
operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype)
+
+ def build_pos_embed(self, device=None, dtype=None) -> None:
+ if self.pos_emb_cls == "rope3d":
+ cls_type = VideoRopePosition3DEmb
+ else:
+ raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
+
+ logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
+ kwargs = dict(
+ model_channels=self.model_channels,
+ len_h=self.max_img_h // self.patch_spatial,
+ len_w=self.max_img_w // self.patch_spatial,
+ len_t=self.max_frames // self.patch_temporal,
+ max_fps=self.max_fps,
+ min_fps=self.min_fps,
+ is_learnable=self.pos_emb_learnable,
+ interpolation=self.pos_emb_interpolation,
+ head_dim=self.model_channels // self.num_heads,
+ h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
+ w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
+ t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
+ enable_fps_modulation=self.rope_enable_fps_modulation,
+ device=device,
+ )
+ self.pos_embedder = cls_type(
+ **kwargs, # type: ignore
+ )
+
+ if self.extra_per_block_abs_pos_emb:
+ kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
+ kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
+ kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
+ kwargs["device"] = device
+ kwargs["dtype"] = dtype
+ self.extra_pos_embedder = LearnablePosEmbAxis(
+ **kwargs, # type: ignore
+ )
+
+ def prepare_embedded_sequence(
+ self,
+ x_B_C_T_H_W: torch.Tensor,
+ fps: Optional[torch.Tensor] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+ """
+ Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.
+
+ Args:
+ x_B_C_T_H_W (torch.Tensor): Video tensor of shape (B, C, T, H, W).
+ fps (Optional[torch.Tensor]): Frames per second tensor, used for positional embedding when required.
+ If None, fps modulation is skipped or handled by the embedder's defaults.
+ padding_mask (Optional[torch.Tensor]): Padding mask for the spatial dimensions. If None, a zero
+ mask is created when `self.concat_padding_mask` is True.
+
+ Returns:
+ Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+ - A tensor of shape (B, T, H, W, D) with the embedded sequence.
+ - An optional rotary positional embedding tensor, returned only if the positional embedding class
+ (`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
+ - An optional extra per-block absolute positional embedding tensor, returned only if
+ `self.extra_per_block_abs_pos_emb` is True. Otherwise, None.
+
+ Notes:
+ - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
+ - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
+ - If 'rope' is in `self.pos_emb_cls` (case insensitive), rotary positional embeddings are generated
+ by `self.pos_embedder` for the shape [T, H, W] and returned for use inside attention.
+ - Otherwise, the output of `self.pos_embedder` is added directly to the embedded sequence.
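+
+ Example (shapes only; values are illustrative):
+ With patch_spatial=2, patch_temporal=1 and concat_padding_mask=True, an input
+ x_B_C_T_H_W of shape (1, 16, 16, 44, 80) becomes (1, 17, 16, 44, 80) once the
+ mask channel is concatenated, and embeds to x_B_T_H_W_D of shape (1, 16, 22, 40, D).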
+ """
+ if self.concat_padding_mask:
+ if padding_mask is None:
+ padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
+ else:
+ padding_mask = transforms.functional.resize(
+ padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
+ )
+ x_B_C_T_H_W = torch.cat(
+ [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
+ )
+ x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
+
+ if self.extra_per_block_abs_pos_emb:
+ extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
+ else:
+ extra_pos_emb = None
+
+ if "rope" in self.pos_emb_cls.lower():
+ return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
+ x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
+
+ return x_B_T_H_W_D, None, extra_pos_emb
+
+ def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor:
+ # (B, T, H, W, p1 * p2 * t * C) -> (B, C, T * t, H * p1, W * p2)
+ x_B_C_Tt_Hp_Wp = rearrange(
+ x_B_T_H_W_M,
+ "B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
+ p1=self.patch_spatial,
+ p2=self.patch_spatial,
+ t=self.patch_temporal,
+ )
+ return x_B_C_Tt_Hp_Wp
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ timesteps: torch.Tensor,
+ context: torch.Tensor,
+ fps: Optional[torch.Tensor] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ):
+ """
+ Args:
+ x: (B, C, T, H, W) tensor of spatial-temp inputs
+ timesteps: (B,) or (B, T) tensor of timesteps
+ context: (B, N, D) tensor of cross-attention embeddings
+ """
+ x_B_C_T_H_W = x
+ timesteps_B_T = timesteps
+ crossattn_emb = context
+ x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
+ x_B_C_T_H_W,
+ fps=fps,
+ padding_mask=padding_mask,
+ )
+
+ if timesteps_B_T.ndim == 1:
+ timesteps_B_T = timesteps_B_T.unsqueeze(1)
+ t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype))
+ t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D)
+
+ # for logging purposes
+ affline_scale_log_info = {}
+ affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach()
+ self.affline_scale_log_info = affline_scale_log_info
+ self.affline_emb = t_embedding_B_T_D
+ self.crossattn_emb = crossattn_emb
+
+ if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
+ assert (
+ x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
+ ), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}"
+
+ block_kwargs = {
+ "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
+ "adaln_lora_B_T_3D": adaln_lora_B_T_3D,
+ "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
+ }
+ for block in self.blocks:
+ x_B_T_H_W_D = block(
+ x_B_T_H_W_D,
+ t_embedding_B_T_D,
+ crossattn_emb,
+ **block_kwargs,
+ )
+
+ x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
+ x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
+ return x_B_C_Tt_Hp_Wp
diff --git a/comfy/ldm/flux/controlnet.py b/comfy/ldm/flux/controlnet.py
index 5322c4891..dbd2a47c0 100644
--- a/comfy/ldm/flux/controlnet.py
+++ b/comfy/ldm/flux/controlnet.py
@@ -121,6 +121,9 @@ class ControlNetFlux(Flux):
 if img.ndim != 3 or txt.ndim != 3:
 raise ValueError("Input img and txt tensors must have 3 dimensions.")
+ if y is None:
+ y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device,
dtype=img.dtype) + # running on sequences img img = self.img_in(img) @@ -174,7 +177,7 @@ class ControlNetFlux(Flux): out["output"] = out_output[:self.main_model_single] return out - def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs): + def forward(self, x, timesteps, context, y=None, guidance=None, hint=None, **kwargs): patch_size = 2 if self.latent_input: hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size)) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index ef4ba4106..846703d52 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -101,6 +101,10 @@ class Flux(nn.Module): transformer_options={}, attn_mask: Tensor = None, ) -> Tensor: + + if y is None: + y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) + patches_replace = transformer_options.get("patches_replace", {}) if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -155,6 +159,9 @@ class Flux(nn.Module): if add is not None: img += add + if img.dtype == torch.float16: + img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504) + img = torch.cat((txt, img), 1) for i, block in enumerate(self.single_blocks): @@ -188,7 +195,7 @@ class Flux(nn.Module): img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img - def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs): + def forward(self, x, timestep, context, y=None, guidance=None, control=None, transformer_options={}, **kwargs): bs, c, h, w = x.shape patch_size = self.patch_size x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index 72af3d5bb..fbd8d4196 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -228,6 +228,7 @@ class HunyuanVideo(nn.Module): y: Tensor, guidance: Tensor = None, guiding_frame_index=None, + ref_latent=None, control=None, transformer_options={}, ) -> Tensor: @@ -238,6 +239,14 @@ class HunyuanVideo(nn.Module): img = self.img_in(img) vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype)) + if ref_latent is not None: + ref_latent_ids = self.img_ids(ref_latent) + ref_latent = self.img_in(ref_latent) + img = torch.cat([ref_latent, img], dim=-2) + ref_latent_ids[..., 0] = -1 + ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1]) + img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2) + if guiding_frame_index is not None: token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0)) vec_ = self.vector_in(y[:, :self.params.vec_in_dim]) @@ -313,6 +322,8 @@ class HunyuanVideo(nn.Module): img[:, : img_len] += add img = img[:, : img_len] + if ref_latent is not None: + img = img[:, ref_latent.shape[1]:] img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels) @@ -324,7 +335,7 @@ class HunyuanVideo(nn.Module): img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) return img - def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs): + def img_ids(self, x): bs, c, t, h, w = x.shape patch_size = self.patch_size t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) @@ -334,7 +345,11 @@ class 
HunyuanVideo(nn.Module): img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1) img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1) img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1) - img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs) + return repeat(img_ids, "t h w c -> b (t h w) c", b=bs) + + def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): + bs, c, t, h, w = x.shape + img_ids = self.img_ids(x) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) - out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options) + out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options) return out diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 056e101a4..ad9a7daea 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -261,8 +261,8 @@ class CrossAttention(nn.Module): self.heads = heads self.dim_head = dim_head - self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device) - self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device) + self.q_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device) + self.k_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device) self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device) self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 45f9e311e..35d2270ee 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -20,8 +20,11 @@ if model_management.xformers_enabled(): if model_management.sage_attention_enabled(): try: from sageattention import sageattn - except ModuleNotFoundError: - logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention") + except ModuleNotFoundError as e: + if e.name == "sageattention": + logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention") + else: + raise e exit(-1) if model_management.flash_attention_enabled(): @@ -750,7 +753,7 @@ class BasicTransformerBlock(nn.Module): for p in patch: n = p(n, extra_options) - x += n + x = n + x if "middle_patch" in transformer_patches: patch = transformer_patches["middle_patch"] for p in patch: @@ -790,12 +793,12 @@ class BasicTransformerBlock(nn.Module): for p in patch: n = p(n, extra_options) - x += n + x = n + x if self.is_res: x_skip = x x = self.ff(self.norm3(x)) if self.is_res: - x += x_skip + x = x_skip + x return x diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index fc5ff40c5..1d6edb354 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -247,6 +247,60 @@ class VaceWanAttentionBlock(WanAttentionBlock): return c_skip, c +class 
WanCamAdapter(nn.Module): + def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1, operation_settings={}): + super(WanCamAdapter, self).__init__() + + # Pixel Unshuffle: reduce spatial dimensions by a factor of 8 + self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8) + + # Convolution: reduce spatial dimensions by a factor + # of 2 (without overlap) + self.conv = operation_settings.get("operations").Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + # Residual blocks for feature extraction + self.residual_blocks = nn.Sequential( + *[WanCamResidualBlock(out_dim, operation_settings = operation_settings) for _ in range(num_residual_blocks)] + ) + + def forward(self, x): + # Reshape to merge the frame dimension into batch + bs, c, f, h, w = x.size() + x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w) + + # Pixel Unshuffle operation + x_unshuffled = self.pixel_unshuffle(x) + + # Convolution operation + x_conv = self.conv(x_unshuffled) + + # Feature extraction with residual blocks + out = self.residual_blocks(x_conv) + + # Reshape to restore original bf dimension + out = out.view(bs, f, out.size(1), out.size(2), out.size(3)) + + # Permute dimensions to reorder (if needed), e.g., swap channels and feature frames + out = out.permute(0, 2, 1, 3, 4) + + return out + + +class WanCamResidualBlock(nn.Module): + def __init__(self, dim, operation_settings={}): + super(WanCamResidualBlock, self).__init__() + self.conv1 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.relu = nn.ReLU(inplace=True) + self.conv2 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, x): + residual = x + out = self.relu(self.conv1(x)) + out = self.conv2(out) + out += residual + return out + + class Head(nn.Module): def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}): @@ -485,13 +539,20 @@ class WanModel(torch.nn.Module): x = self.unpatchify(x, grid_sizes) return x - def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs): + def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs): bs, c, t, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) + patch_size = self.patch_size t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) w_len = ((w + (patch_size[2] // 2)) // patch_size[2]) + + if time_dim_concat is not None: + time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size) + x = torch.cat([x, time_dim_concat], dim=2) + t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0]) + img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype) img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1) img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1) @@ -581,7 +642,7 @@ class VaceWanModel(WanModel): t, context, vace_context, - vace_strength=1.0, + vace_strength, clip_fea=None, freqs=None, transformer_options={}, @@ -607,8 
+668,11 @@ class VaceWanModel(WanModel): context = torch.concat([context_clip, context], dim=1) context_img_len = clip_fea.shape[-2] + orig_shape = list(vace_context.shape) + vace_context = vace_context.movedim(0, 1).reshape([-1] + orig_shape[2:]) c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype) c = c.flatten(2).transpose(1, 2) + c = list(c.split(orig_shape[0], dim=0)) # arguments x_orig = x @@ -628,8 +692,9 @@ class VaceWanModel(WanModel): ii = self.vace_layers_mapping.get(i, None) if ii is not None: - c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) - x += c_skip * vace_strength + for iii in range(len(c)): + c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x += c_skip * vace_strength[iii] del c_skip # head x = self.head(x, e) @@ -637,3 +702,92 @@ class VaceWanModel(WanModel): # unpatchify x = self.unpatchify(x, grid_sizes) return x + +class CameraWanModel(WanModel): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + def __init__(self, + model_type='camera', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + flf_pos_embed_token_number=None, + image_model=None, + in_dim_control_adapter=24, + device=None, + dtype=None, + operations=None, + ): + + super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + + self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings) + + + def forward_orig( + self, + x, + t, + context, + clip_fea=None, + freqs=None, + camera_conditions = None, + transformer_options={}, + **kwargs, + ): + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + if self.control_adapter is not None and camera_conditions is not None: + x_camera = self.control_adapter(camera_conditions).to(x.dtype) + x = x + x_camera + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype)) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + + # context + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + return out + out = 
blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return x diff --git a/comfy/lora.py b/comfy/lora.py index fff524be2..387d5c52a 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -283,8 +283,15 @@ def model_lora_keys_unet(model, key_map={}): for k in sdk: if k.startswith("diffusion_model."): if k.endswith(".weight"): - key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") - key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format + key_lora = k[len("diffusion_model."):-len(".weight")] + key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format + key_map["transformer.{}".format(key_lora)] = k #SimpleTuner regular format + + if isinstance(model, comfy.model_base.ACEStep): + for k in sdk: + if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official ACE step lora format + key_lora = k[len("diffusion_model."):-len(".weight")] + key_map["{}".format(key_lora)] = k return key_map diff --git a/comfy/model_base.py b/comfy/model_base.py index 6d27930dc..cb7689e84 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -34,6 +34,7 @@ import comfy.ldm.flux.model import comfy.ldm.lightricks.model import comfy.ldm.hunyuan_video.model import comfy.ldm.cosmos.model +import comfy.ldm.cosmos.predict2 import comfy.ldm.lumina.model import comfy.ldm.wan.model import comfy.ldm.hunyuan3d.model @@ -48,6 +49,7 @@ import comfy.ops from enum import Enum from . import utils import comfy.latent_formats +import comfy.model_sampling import math from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -63,38 +65,39 @@ class ModelType(Enum): V_PREDICTION_CONTINUOUS = 7 FLUX = 8 IMG_TO_IMG = 9 - - -from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV + FLOW_COSMOS = 10 def model_sampling(model_config, model_type): - s = ModelSamplingDiscrete + s = comfy.model_sampling.ModelSamplingDiscrete if model_type == ModelType.EPS: - c = EPS + c = comfy.model_sampling.EPS elif model_type == ModelType.V_PREDICTION: - c = V_PREDICTION + c = comfy.model_sampling.V_PREDICTION elif model_type == ModelType.V_PREDICTION_EDM: - c = V_PREDICTION - s = ModelSamplingContinuousEDM + c = comfy.model_sampling.V_PREDICTION + s = comfy.model_sampling.ModelSamplingContinuousEDM elif model_type == ModelType.FLOW: c = comfy.model_sampling.CONST s = comfy.model_sampling.ModelSamplingDiscreteFlow elif model_type == ModelType.STABLE_CASCADE: - c = EPS - s = StableCascadeSampling + c = comfy.model_sampling.EPS + s = comfy.model_sampling.StableCascadeSampling elif model_type == ModelType.EDM: - c = EDM - s = ModelSamplingContinuousEDM + c = comfy.model_sampling.EDM + s = comfy.model_sampling.ModelSamplingContinuousEDM elif model_type == ModelType.V_PREDICTION_CONTINUOUS: - c = V_PREDICTION - s = ModelSamplingContinuousV + c = comfy.model_sampling.V_PREDICTION + s = comfy.model_sampling.ModelSamplingContinuousV elif model_type == ModelType.FLUX: c = comfy.model_sampling.CONST s = comfy.model_sampling.ModelSamplingFlux elif model_type == ModelType.IMG_TO_IMG: c = comfy.model_sampling.IMG_TO_IMG + elif model_type == ModelType.FLOW_COSMOS: + c = comfy.model_sampling.COSMOS_RFLOW + s = 
comfy.model_sampling.ModelSamplingCosmosRFlow class ModelSampling(s, c): pass @@ -102,6 +105,13 @@ def model_sampling(model_config, model_type): return ModelSampling(model_config) +def convert_tensor(extra, dtype): + if hasattr(extra, "dtype"): + if extra.dtype != torch.int and extra.dtype != torch.long: + extra = extra.to(dtype) + return extra + + class BaseModel(torch.nn.Module): def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel): super().__init__() @@ -135,6 +145,7 @@ class BaseModel(torch.nn.Module): logging.info("model_type {}".format(model_type.name)) logging.debug("adm {}".format(self.adm_channels)) self.memory_usage_factor = model_config.memory_usage_factor + self.memory_usage_factor_conds = () def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( @@ -164,9 +175,14 @@ class BaseModel(torch.nn.Module): extra_conds = {} for o in kwargs: extra = kwargs[o] + if hasattr(extra, "dtype"): - if extra.dtype != torch.int and extra.dtype != torch.long: - extra = extra.to(dtype) + extra = convert_tensor(extra, dtype) + elif isinstance(extra, list): + ex = [] + for ext in extra: + ex.append(convert_tensor(ext, dtype)) + extra = ex extra_conds[o] = extra t = self.process_timestep(t, x=x, **extra_conds) @@ -325,19 +341,28 @@ class BaseModel(torch.nn.Module): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return self.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)), noise, latent_image) - def memory_required(self, input_shape): + def memory_required(self, input_shape, cond_shapes={}): + input_shapes = [input_shape] + for c in self.memory_usage_factor_conds: + shape = cond_shapes.get(c, None) + if shape is not None and len(shape) > 0: + input_shapes += shape + if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): dtype = self.get_dtype() if self.manual_cast_dtype is not None: dtype = self.manual_cast_dtype #TODO: this needs to be tweaked - area = input_shape[0] * math.prod(input_shape[2:]) + area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes)) return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024) else: #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. 
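+ # Illustrative arithmetic: an input_shape of (2, 16, 16, 44, 80) contributes
+ # area = 2 * (16 * 44 * 80) = 112640; any shapes gathered in cond_shapes are
+ # added to the same sum before the memory_usage_factor scaling.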
- area = input_shape[0] * math.prod(input_shape[2:]) + area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes)) return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024) + def extra_conds_shapes(self, **kwargs): + return {} + def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None): adm_inputs = [] @@ -924,6 +949,10 @@ class HunyuanVideo(BaseModel): if guiding_frame_index is not None: out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index])) + ref_latent = kwargs.get("ref_latent", None) + if ref_latent is not None: + out['ref_latent'] = comfy.conds.CONDRegular(self.process_latent_in(ref_latent)) + return out def scale_latent_inpaint(self, latent_image, **kwargs): @@ -972,6 +1001,43 @@ class CosmosVideo(BaseModel): latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image) return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5) +class CosmosPredict2(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW_COSMOS, image_to_video=False, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.predict2.MiniTrainDIT) + self.image_to_video = image_to_video + if self.image_to_video: + self.concat_keys = ("mask_inverted",) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if denoise_mask is not None: + out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask) + + out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None)) + return out + + def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): + if denoise_mask is None: + return timestep + condition_video_mask_B_1_T_1_1 = denoise_mask.mean(dim=[1, 3, 4], keepdim=True) + c_noise_B_1_T_1_1 = 0.0 * (1.0 - condition_video_mask_B_1_T_1_1) + timestep.reshape(timestep.shape[0], 1, 1, 1, 1) * condition_video_mask_B_1_T_1_1 + out = c_noise_B_1_T_1_1.squeeze(dim=[1, 3, 4]) + return out + + def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): + sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)) + sigma_noise_augmentation = 0 #TODO + if sigma_noise_augmentation != 0: + latent_image = latent_image + noise + latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image) + sigma = (sigma / (sigma + 1)) + return latent_image / (1.0 - sigma) + class Lumina2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT) @@ -1043,6 +1109,11 @@ class WAN21(BaseModel): clip_vision_output = kwargs.get("clip_vision_output", None) if clip_vision_output is not None: out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states) + + time_dim_concat = kwargs.get("time_dim_concat", None) + if time_dim_concat is not None: + out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat)) + return out @@ -1058,23 +1129,39 @@ class WAN21_Vace(WAN21): vace_frames = kwargs.get("vace_frames", None) if vace_frames 
is None: noise_shape[1] = 32 - vace_frames = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype) - - for i in range(0, vace_frames.shape[1], 16): - vace_frames = vace_frames.clone() - vace_frames[:, i:i + 16] = self.process_latent_in(vace_frames[:, i:i + 16]) + vace_frames = [torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)] mask = kwargs.get("vace_mask", None) if mask is None: noise_shape[1] = 64 - mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype) + mask = [torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)] * len(vace_frames) - out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1)) + vace_frames_out = [] + for j in range(len(vace_frames)): + vf = vace_frames[j].clone() + for i in range(0, vf.shape[1], 16): + vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16]) + vf = torch.cat([vf, mask[j]], dim=1) + vace_frames_out.append(vf) - vace_strength = kwargs.get("vace_strength", 1.0) + vace_frames = torch.stack(vace_frames_out, dim=1) + out['vace_context'] = comfy.conds.CONDRegular(vace_frames) + + vace_strength = kwargs.get("vace_strength", [1.0] * len(vace_frames_out)) out['vace_strength'] = comfy.conds.CONDConstant(vace_strength) return out +class WAN21_Camera(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel) + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + camera_conditions = kwargs.get("camera_conditions", None) + if camera_conditions is not None: + out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions) + return out class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 28c586389..4aa90d3b6 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -361,6 +361,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "vace" dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1] dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.') + elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "camera" else: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "i2v" @@ -405,6 +407,58 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["text_emb_dim"] = 2048 return dit_config + if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2 + dit_config = {} + dit_config["image_model"] = "cosmos_predict2" + dit_config["max_img_h"] = 240 + dit_config["max_img_w"] = 240 + dit_config["max_frames"] = 128 + concat_padding_mask = True + dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask) + dit_config["out_channels"] = 16 + dit_config["patch_spatial"] = 2 + dit_config["patch_temporal"] = 1 + dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0] + dit_config["concat_padding_mask"] = concat_padding_mask + dit_config["crossattn_emb_channels"] = 1024 + dit_config["pos_emb_cls"] = "rope3d" + 
dit_config["pos_emb_learnable"] = True + dit_config["pos_emb_interpolation"] = "crop" + dit_config["min_fps"] = 1 + dit_config["max_fps"] = 30 + + dit_config["use_adaln_lora"] = True + dit_config["adaln_lora_dim"] = 256 + if dit_config["model_channels"] == 2048: + dit_config["num_blocks"] = 28 + dit_config["num_heads"] = 16 + elif dit_config["model_channels"] == 5120: + dit_config["num_blocks"] = 36 + dit_config["num_heads"] = 40 + + if dit_config["in_channels"] == 16: + dit_config["extra_per_block_abs_pos_emb"] = False + dit_config["rope_h_extrapolation_ratio"] = 4.0 + dit_config["rope_w_extrapolation_ratio"] = 4.0 + dit_config["rope_t_extrapolation_ratio"] = 1.0 + elif dit_config["in_channels"] == 17: # img to video + if dit_config["model_channels"] == 2048: + dit_config["extra_per_block_abs_pos_emb"] = False + dit_config["rope_h_extrapolation_ratio"] = 3.0 + dit_config["rope_w_extrapolation_ratio"] = 3.0 + dit_config["rope_t_extrapolation_ratio"] = 1.0 + elif dit_config["model_channels"] == 5120: + dit_config["rope_h_extrapolation_ratio"] = 2.0 + dit_config["rope_w_extrapolation_ratio"] = 2.0 + dit_config["rope_t_extrapolation_ratio"] = 0.8333333333333334 + + dit_config["extra_h_extrapolation_ratio"] = 1.0 + dit_config["extra_w_extrapolation_ratio"] = 1.0 + dit_config["extra_t_extrapolation_ratio"] = 1.0 + dit_config["rope_enable_fps_modulation"] = False + + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None @@ -618,6 +672,9 @@ def convert_config(unet_config): def unet_config_from_diffusers_unet(state_dict, dtype=None): + if "conv_in.weight" not in state_dict: + return None + match = {} transformer_depth = [] diff --git a/comfy/model_management.py b/comfy/model_management.py index c2b41ce0c..6a61c4d9c 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -319,14 +319,24 @@ except: pass +SUPPORT_FP8_OPS = args.supports_fp8_compute try: if is_amd(): + try: + rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2])) + except: + rocm_version = (6, -1) arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName logging.info("AMD arch: {}".format(arch)) + logging.info("ROCm version: {}".format(rocm_version)) if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: - if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much - if any((a in arch) for a in ["gfx1100", "gfx1101"]): # TODO: more arches + if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much + if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx1201 and gfx950 ENABLE_PYTORCH_ATTENTION = True + if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4): + if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches + SUPPORT_FP8_OPS = True + except: pass @@ -347,7 +357,7 @@ except: pass try: - if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5: + if torch_version_numeric >= (2, 5): torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True) except: logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp") @@ -723,7 +733,7 @@ def unet_inital_load_device(parameters, dtype): return torch_dev cpu_dev = torch.device("cpu") - if DISABLE_SMART_MEMORY: + if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM: return cpu_dev model_size = dtype_size(dtype) * parameters @@ -1070,7 +1080,7 
@@ def pytorch_attention_flash_attention(): global ENABLE_PYTORCH_ATTENTION if ENABLE_PYTORCH_ATTENTION: #TODO: more reliable way of checking for flash attention? - if is_nvidia(): #pytorch flash attention only works on Nvidia + if is_nvidia(): return True if is_intel_xpu(): return True @@ -1086,7 +1096,7 @@ def force_upcast_attention_dtype(): upcast = args.force_upcast_attention macos_version = mac_version() - if macos_version is not None and ((14, 5) <= macos_version < (16,)): # black image bug on recent versions of macOS + if macos_version is not None and ((14, 5) <= macos_version): # black image bug on recent versions of macOS, I don't think it's ever getting fixed upcast = True if upcast: @@ -1285,6 +1295,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma return False def supports_fp8_compute(device=None): + if SUPPORT_FP8_OPS: + return True + if not is_nvidia(): return False @@ -1296,11 +1309,11 @@ def supports_fp8_compute(device=None): if props.minor < 9: return False - if torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 3): + if torch_version_numeric < (2, 3): return False if WINDOWS: - if (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 4): + if torch_version_numeric < (2, 4): return False return True diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index deb8af327..7cc62c4ed 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -17,23 +17,26 @@ """ from __future__ import annotations -from typing import Optional, Callable -import torch + +import collections import copy import inspect import logging -import uuid -import collections import math +import uuid +from typing import Callable, Optional + +import torch -import comfy.utils import comfy.float -import comfy.model_management -import comfy.lora import comfy.hooks +import comfy.lora +import comfy.model_management import comfy.patcher_extension -from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection +import comfy.utils from comfy.comfy_types import UnetWrapperFunction +from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP + def string_to_seed(data): crc = 0xFFFFFFFF diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 7e7291476..b240b7f29 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -77,6 +77,25 @@ class IMG_TO_IMG(X0): def calculate_input(self, sigma, noise): return noise +class COSMOS_RFLOW: + def calculate_input(self, sigma, noise): + sigma = (sigma / (sigma + 1)) + sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + return noise * (1.0 - sigma) + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = (sigma / (sigma + 1)) + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input * (1.0 - sigma) - model_output * sigma + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + noise = noise * sigma + noise += latent_image + return noise + + def inverse_noise_scaling(self, sigma, latent): + return latent class ModelSamplingDiscrete(torch.nn.Module): def __init__(self, model_config=None, zsnr=None): @@ -350,3 +369,15 @@ class ModelSamplingFlux(torch.nn.Module): if percent >= 1.0: return 0.0 return flux_time_shift(self.shift, 1.0, 1.0 - percent) + + +class ModelSamplingCosmosRFlow(ModelSamplingContinuousEDM): + def timestep(self, sigma): + return sigma / (sigma + 1) + + def 
sigma(self, timestep): + sigma_max = self.sigma_max + if timestep >= (sigma_max / (sigma_max + 1)): + return sigma_max + + return timestep / (1 - timestep) diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py index 9d82bee1a..66ae8321d 100644 --- a/comfy/rmsnorm.py +++ b/comfy/rmsnorm.py @@ -30,7 +30,7 @@ if RMSNorm is None: def __init__( self, normalized_shape, - eps=None, + eps=1e-6, elementwise_affine=True, device=None, dtype=None, diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index f2cee7874..ccb3e39a5 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -1,6 +1,8 @@ from __future__ import annotations import torch import uuid +import math +import collections import comfy.model_management import comfy.conds import comfy.model_patcher @@ -147,6 +149,22 @@ def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPat curr_cnet = prev_cnet # potentially handle gligen - since not widely used, ignored for now +def estimate_memory(model, noise_shape, conds): + cond_shapes = collections.defaultdict(list) + cond_shapes_min = {} + for _, cs in conds.items(): + for cond in cs: + for k, v in model.model.extra_conds_shapes(**cond).items(): + cond_shapes[k].append(v) + if cond_shapes_min.get(k, None) is None: + cond_shapes_min[k] = [v] + elif math.prod(v) > math.prod(cond_shapes_min[k][0]): + cond_shapes_min[k] = [v] + + memory_required = model.model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:]), cond_shapes=cond_shapes) + minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min) + return memory_required, minimum_memory_required + def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None): executor = comfy.patcher_extension.WrapperExecutor.new_executor( _prepare_sampling, @@ -160,10 +178,9 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non models, inference_memory = get_additional_models(conds, model.model_dtype()) models += get_additional_models_from_model_options(model_options) models += model.get_nested_additional_models() # TODO: does this require inference_memory update? 
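+    # estimate_memory() (defined above) folds the conditioning tensor shapes
+    # into the budget: the doubled-batch figure is the comfortable target,
+    # while the single-batch figure (keeping only the largest cond shape per
+    # key) is the minimum floor model management may fall back to.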
- memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory - minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory - comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required) - real_model: BaseModel = model.model + memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds) + comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory) + real_model = model.model return real_model, conds, models diff --git a/comfy/samplers.py b/comfy/samplers.py index 90cce078d..7db6a68b4 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -263,7 +263,13 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te for i in range(1, len(to_batch_temp) + 1): batch_amount = to_batch_temp[:len(to_batch_temp)//i] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] - if model.memory_required(input_shape) * 1.5 < free_memory: + cond_shapes = collections.defaultdict(list) + for tt in batch_amount: + cond = {k: v.size() for k, v in to_run[tt][0].conditioning.items()} + for k, v in to_run[tt][0].conditioning.items(): + cond_shapes[k].append(v.size()) + + if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory: to_batch = batch_amount break diff --git a/comfy/sd.py b/comfy/sd.py index e98a3aa87..cd13ab5f0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1081,7 +1081,28 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return (model_patcher, clip, vae, clipvision) -def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format +def load_diffusion_model_state_dict(sd, model_options={}): + """ + Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats. + + Args: + sd (dict): State dictionary containing model weights and configuration + model_options (dict, optional): Additional options for model loading. Supports: + - dtype: Override model data type + - custom_operations: Custom model operations + - fp8_optimizations: Enable FP8 optimizations + + Returns: + ModelPatcher: A wrapped model instance that handles device management and weight loading. + Returns None if the model configuration cannot be detected. + + The function: + 1. Detects and handles different model formats (regular, diffusers, mmdit) + 2. Configures model dtype based on parameters and device capabilities + 3. Handles weight conversion and device placement + 4. Manages model optimization settings + 5. 
Loads weights and returns a device-managed model instance + """ dtype = model_options.get("dtype", None) #Allow loading unets from checkpoint files diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index ac61babe9..1b69a4103 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -462,7 +462,7 @@ class SDTokenizer: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length) - self.min_length = min_length + self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length) self.end_token = None self.min_padding = min_padding diff --git a/comfy/supported_models.py b/comfy/supported_models.py index fef25eb24..19f25e337 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -908,6 +908,48 @@ class CosmosI2V(CosmosT2V): out = model_base.CosmosVideo(self, image_to_video=True, device=device) return out +class CosmosT2IPredict2(supported_models_base.BASE): + unet_config = { + "image_model": "cosmos_predict2", + "in_channels": 16, + } + + sampling_settings = { + "sigma_data": 1.0, + "sigma_max": 80.0, + "sigma_min": 0.002, + } + + unet_extra_config = {} + latent_format = latent_formats.Wan21 + + memory_usage_factor = 1.0 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.CosmosPredict2(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect)) + +class CosmosI2VPredict2(CosmosT2IPredict2): + unet_config = { + "image_model": "cosmos_predict2", + "in_channels": 17, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.CosmosPredict2(self, image_to_video=True, device=device) + return out + class Lumina2(supported_models_base.BASE): unet_config = { "image_model": "lumina2", @@ -992,6 +1034,16 @@ class WAN21_FunControl2V(WAN21_T2V): out = model_base.WAN21(self, image_to_video=False, device=device) return out +class WAN21_Camera(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "camera", + "in_dim": 32, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_Camera(self, image_to_video=False, device=device) + return out class WAN21_Vace(WAN21_T2V): unet_config = { "image_model": "wan2.1", @@ -1129,6 +1181,6 @@ class ACEStep(supported_models_base.BASE): def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, 
CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep] models += [SVD_img2vid] diff --git a/comfy/text_encoders/long_clipl.json b/comfy/text_encoders/long_clipl.json deleted file mode 100644 index 5e2056ff3..000000000 --- a/comfy/text_encoders/long_clipl.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "_name_or_path": "openai/clip-vit-large-patch14", - "architectures": [ - "CLIPTextModel" - ], - "attention_dropout": 0.0, - "bos_token_id": 0, - "dropout": 0.0, - "eos_token_id": 49407, - "hidden_act": "quick_gelu", - "hidden_size": 768, - "initializer_factor": 1.0, - "initializer_range": 0.02, - "intermediate_size": 3072, - "layer_norm_eps": 1e-05, - "max_position_embeddings": 248, - "model_type": "clip_text_model", - "num_attention_heads": 12, - "num_hidden_layers": 12, - "pad_token_id": 1, - "projection_dim": 768, - "torch_dtype": "float32", - "transformers_version": "4.24.0", - "vocab_size": 49408 -} diff --git a/comfy/utils.py b/comfy/utils.py index 561e1b858..1f8d71292 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -78,8 +78,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args) else: pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle) - if "global_step" in pl_sd: - logging.debug(f"Global Step: {pl_sd['global_step']}") if "state_dict" in pl_sd: sd = pl_sd["state_dict"] else: diff --git a/comfy/weight_adapter/__init__.py b/comfy/weight_adapter/__init__.py index d2a1d0151..560b82be3 100644 --- a/comfy/weight_adapter/__init__.py +++ b/comfy/weight_adapter/__init__.py @@ -1,4 +1,4 @@ -from .base import WeightAdapterBase +from .base import WeightAdapterBase, WeightAdapterTrainBase from .lora import LoRAAdapter from .loha import LoHaAdapter from .lokr import LoKrAdapter @@ -15,3 +15,9 @@ adapters: list[type[WeightAdapterBase]] = [ OFTAdapter, BOFTAdapter, ] + +__all__ = [ + "WeightAdapterBase", + "WeightAdapterTrainBase", + "adapters" +] + [a.__name__ for a in adapters] diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py index 29873519d..b5c7db423 100644 --- a/comfy/weight_adapter/base.py +++ b/comfy/weight_adapter/base.py @@ -12,12 +12,20 @@ class WeightAdapterBase: weights: list[torch.Tensor] @classmethod - def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]: + def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]: raise NotImplementedError def to_train(self) -> "WeightAdapterTrainBase": raise NotImplementedError + @classmethod + def create_train(cls, weight, *args) -> "WeightAdapterTrainBase": + """ + weight: The original weight tensor to be modified. + *args: Additional arguments for configuration, such as rank, alpha etc. 
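+
+        A hypothetical call, using the LoRAAdapter subclass defined in
+        lora.py (rank/alpha values are illustrative only):
+            LoRAAdapter.create_train(weight, rank=8, alpha=8.0)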
+ """ + raise NotImplementedError + def calculate_weight( self, weight, @@ -33,10 +41,22 @@ class WeightAdapterBase: class WeightAdapterTrainBase(nn.Module): + # We follow the scheme of PR #7032 def __init__(self): super().__init__() - # [TODO] Collaborate with LoRA training PR #7032 + def __call__(self, w): + """ + w: The original weight tensor to be modified. + """ + raise NotImplementedError + + def passive_memory_usage(self): + raise NotImplementedError("passive_memory_usage is not implemented") + + def move_to(self, device): + self.to(device) + return self.passive_memory_usage() def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function): @@ -102,3 +122,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten padded_tensor[new_slices] = tensor[orig_slices] return padded_tensor + + +def tucker_weight_from_conv(up, down, mid): + up = up.reshape(up.size(0), up.size(1)) + down = down.reshape(down.size(0), down.size(1)) + return torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down) + + +def tucker_weight(wa, wb, t): + temp = torch.einsum("i j ..., j r -> i r ...", t, wb) + return torch.einsum("i j ..., i r -> r j ...", temp, wa) diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index b2e623924..729dbd9e6 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -3,7 +3,56 @@ from typing import Optional import torch import comfy.model_management -from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape +from .base import ( + WeightAdapterBase, + WeightAdapterTrainBase, + weight_decompose, + pad_tensor_to_shape, + tucker_weight_from_conv, +) + + +class LoraDiff(WeightAdapterTrainBase): + def __init__(self, weights): + super().__init__() + mat1, mat2, alpha, mid, dora_scale, reshape = weights + out_dim, rank = mat1.shape[0], mat1.shape[1] + rank, in_dim = mat2.shape[0], mat2.shape[1] + if mid is not None: + convdim = mid.ndim - 2 + layer = ( + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d + )[convdim] + else: + layer = torch.nn.Linear + self.lora_up = layer(rank, out_dim, bias=False) + self.lora_down = layer(in_dim, rank, bias=False) + self.lora_up.weight.data.copy_(mat1) + self.lora_down.weight.data.copy_(mat2) + if mid is not None: + self.lora_mid = layer(mid, rank, bias=False) + self.lora_mid.weight.data.copy_(mid) + else: + self.lora_mid = None + self.rank = rank + self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False) + + def __call__(self, w): + org_dtype = w.dtype + if self.lora_mid is None: + diff = self.lora_up.weight @ self.lora_down.weight + else: + diff = tucker_weight_from_conv( + self.lora_up.weight, self.lora_down.weight, self.lora_mid.weight + ) + scale = self.alpha / self.rank + weight = w + scale * diff.reshape(w.shape) + return weight.to(org_dtype) + + def passive_memory_usage(self): + return sum(param.numel() * param.element_size() for param in self.parameters()) class LoRAAdapter(WeightAdapterBase): @@ -13,6 +62,21 @@ class LoRAAdapter(WeightAdapterBase): self.loaded_keys = loaded_keys self.weights = weights + @classmethod + def create_train(cls, weight, rank=1, alpha=1.0): + out_dim = weight.shape[0] + in_dim = weight.shape[1:].numel() + mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) + mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) + torch.nn.init.kaiming_uniform_(mat1, a=5**0.5) + torch.nn.init.constant_(mat2, 0.0) + return LoraDiff( + (mat1, mat2, 
alpha, None, None, None) + ) + + def to_train(self): + return LoraDiff(self.weights) + @classmethod def load( cls, diff --git a/comfy_api/input/video_types.py b/comfy_api/input/video_types.py index 0676e0e66..dc22d34ff 100644 --- a/comfy_api/input/video_types.py +++ b/comfy_api/input/video_types.py @@ -43,3 +43,13 @@ class VideoInput(ABC): components = self.get_components() return components.images.shape[2], components.images.shape[1] + def get_duration(self) -> float: + """ + Returns the duration of the video in seconds. + + Returns: + Duration in seconds + """ + components = self.get_components() + frame_count = components.images.shape[0] + return float(frame_count / components.frame_rate) diff --git a/comfy_api/input_impl/video_types.py b/comfy_api/input_impl/video_types.py index ae48dbaa4..197f6558c 100644 --- a/comfy_api/input_impl/video_types.py +++ b/comfy_api/input_impl/video_types.py @@ -80,6 +80,38 @@ class VideoFromFile(VideoInput): return stream.width, stream.height raise ValueError(f"No video stream found in file '{self.__file}'") + def get_duration(self) -> float: + """ + Returns the duration of the video in seconds. + + Returns: + Duration in seconds + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) + with av.open(self.__file, mode="r") as container: + if container.duration is not None: + return float(container.duration / av.time_base) + + # Fallback: calculate from frame count and frame rate + video_stream = next( + (s for s in container.streams if s.type == "video"), None + ) + if video_stream and video_stream.frames and video_stream.average_rate: + return float(video_stream.frames / video_stream.average_rate) + + # Last resort: decode frames to count them + if video_stream and video_stream.average_rate: + frame_count = 0 + container.seek(0) + for packet in container.demux(video_stream): + for _ in packet.decode(): + frame_count += 1 + if frame_count > 0: + return float(frame_count / video_stream.average_rate) + + raise ValueError(f"Could not determine duration for file '{self.__file}'") + def get_components_internal(self, container: InputContainer) -> VideoComponents: # Get video frames frames = [] diff --git a/comfy_api/torch_helpers/__init__.py b/comfy_api/torch_helpers/__init__.py new file mode 100644 index 000000000..be7ae7a61 --- /dev/null +++ b/comfy_api/torch_helpers/__init__.py @@ -0,0 +1,5 @@ +from .torch_compile import set_torch_compile_wrapper + +__all__ = [ + "set_torch_compile_wrapper", +] diff --git a/comfy_api/torch_helpers/torch_compile.py b/comfy_api/torch_helpers/torch_compile.py new file mode 100644 index 000000000..9223f58db --- /dev/null +++ b/comfy_api/torch_helpers/torch_compile.py @@ -0,0 +1,69 @@ +from __future__ import annotations +import torch + +import comfy.utils +from comfy.patcher_extension import WrappersMP +from typing import TYPE_CHECKING, Callable, Optional +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher + from comfy.patcher_extension import WrapperExecutor + + +COMPILE_KEY = "torch.compile" +TORCH_COMPILE_KWARGS = "torch_compile_kwargs" + + +def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable: + ''' + Create a wrapper that will refer to the compiled_diffusion_model. 
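+
+    The wrapper swaps each module in compiled_module_dict onto the executor's
+    class_obj for the duration of the call, then restores the originals in
+    the finally block, so compilation never permanently mutates the model.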
+ ''' + def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs): + try: + orig_modules = {} + for key, value in compiled_module_dict.items(): + orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key) + comfy.utils.set_attr(executor.class_obj, key, value) + return executor(*args, **kwargs) + finally: + for key, value in orig_modules.items(): + comfy.utils.set_attr(executor.class_obj, key, value) + return apply_torch_compile_wrapper + + +def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None, + mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None, + keys: list[str]=["diffusion_model"], *args, **kwargs): + ''' + Perform torch.compile that will be applied at sample time for either the whole model or specific params of the BaseModel instance. + + When keys is None, it will default to using ["diffusion_model"], compiling the whole diffusion_model. + When a list of keys is provided, it will perform torch.compile on only the selected modules. + ''' + # clear out any other torch.compile wrappers + model.remove_wrappers_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY) + # if no keys, default to 'diffusion_model' + if not keys: + keys = ["diffusion_model"] + # create kwargs dict that can be referenced later + compile_kwargs = { + "backend": backend, + "options": options, + "mode": mode, + "fullgraph": fullgraph, + "dynamic": dynamic, + } + # get a dict of compiled keys + compiled_modules = {} + for key in keys: + compiled_modules[key] = torch.compile( + model=model.get_model_object(key), + **compile_kwargs, + ) + # add torch.compile wrapper + wrapper_func = apply_torch_compile_factory( + compiled_module_dict=compiled_modules, + ) + # store wrapper to run on BaseModel's apply_model function + model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func) + # keep compile kwargs for reference + model.model_options[TORCH_COMPILE_KWARGS] = compile_kwargs diff --git a/comfy_api_nodes/README.md b/comfy_api_nodes/README.md index e2633a769..64a389cc1 100644 --- a/comfy_api_nodes/README.md +++ b/comfy_api_nodes/README.md @@ -18,6 +18,8 @@ Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to python run main.py --comfy-api-base https://stagingapi.comfy.org ``` +To authenticate to staging, please login and then ask one of Comfy Org team to whitelist you for access to staging. + API stubs are generated through automatic codegen tools from OpenAPI definitions. Since the Comfy Org OpenAPI definition contains many things from the Comfy Registry as well, we use redocly/cli to filter out only the paths relevant for API nodes. ### Redocly Instructions @@ -28,7 +30,7 @@ When developing locally, use the `redocly-dev.yaml` file to generate pydantic mo Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging. ```bash -# Download the OpenAPI file from prod server. +# Download the OpenAPI file from staging server. curl -o openapi.yaml https://stagingapi.comfy.org/openapi # Filter out unneeded API definitions. @@ -39,3 +41,25 @@ redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_no datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel ``` + + +# Merging to Master + +Before merging to comfyanonymous/ComfyUI master, follow these steps: + +1. 
Add the "Released" tag to the ComfyUI OpenAPI yaml file for each endpoint you are using in the nodes. +1. Make sure the ComfyUI API is deployed to prod with your changes. +1. Run the code generation again with `redocly.yaml` and the production OpenAPI yaml file. + +```bash +# Download the OpenAPI file from prod server. +curl -o openapi.yaml https://api.comfy.org/openapi + +# Filter out unneeded API definitions. +npm install -g @redocly/cli +redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components + +# Generate the pydantic datamodels for validation. +datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel + +``` diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py index e28d7d607..788e2803f 100644 --- a/comfy_api_nodes/apinode_utils.py +++ b/comfy_api_nodes/apinode_utils.py @@ -1,7 +1,8 @@ from __future__ import annotations import io import logging -from typing import Optional +import mimetypes +from typing import Optional, Union from comfy.utils import common_upscale from comfy_api.input_impl import VideoFromFile from comfy_api.util import VideoContainer, VideoCodec @@ -15,6 +16,7 @@ from comfy_api_nodes.apis.client import ( UploadRequest, UploadResponse, ) +from server import PromptServer import numpy as np @@ -60,7 +62,9 @@ def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor: return s -def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor: +def validate_and_cast_response( + response, timeout: int = None, node_id: Union[str, None] = None +) -> torch.Tensor: """Validates and casts a response to a torch.Tensor. Args: @@ -94,6 +98,10 @@ def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor: img = Image.open(io.BytesIO(img_data)) elif image_url: + if node_id: + PromptServer.instance.send_progress_text( + f"Result URL: {image_url}", node_id + ) img_response = requests.get(image_url, timeout=timeout) if img_response.status_code != 200: raise ValueError("Failed to download the image") @@ -207,6 +215,7 @@ def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor: image_bytesio = download_url_to_bytesio(url, timeout) return bytesio_to_image_tensor(image_bytesio) + def process_image_response(response: requests.Response) -> torch.Tensor: """Uses content from a Response object and converts it to a torch.Tensor""" return bytesio_to_image_tensor(BytesIO(response.content)) @@ -311,11 +320,27 @@ def tensor_to_data_uri( return f"data:{mime_type};base64,{base64_string}" +def text_filepath_to_base64_string(filepath: str) -> str: + """Converts a text file to a base64 string.""" + with open(filepath, "rb") as f: + file_content = f.read() + return base64.b64encode(file_content).decode("utf-8") + + +def text_filepath_to_data_uri(filepath: str) -> str: + """Converts a text file to a data URI.""" + base64_string = text_filepath_to_base64_string(filepath) + mime_type, _ = mimetypes.guess_type(filepath) + if mime_type is None: + mime_type = "application/octet-stream" + return f"data:{mime_type};base64,{base64_string}" + + def upload_file_to_comfyapi( file_bytes_io: BytesIO, filename: str, upload_mime_type: str, - auth_kwargs: Optional[dict[str,str]] = None, + auth_kwargs: Optional[dict[str, str]] = None, ) -> str: """ Uploads a single file to ComfyUI API and returns its download URL. 
@@ -350,9 +375,33 @@ def upload_file_to_comfyapi( return response.download_url +def video_to_base64_string( + video: VideoInput, + container_format: VideoContainer = None, + codec: VideoCodec = None +) -> str: + """ + Converts a video input to a base64 string. + + Args: + video: The video input to convert + container_format: Optional container format to use (defaults to video.container if available) + codec: Optional codec to use (defaults to video.codec if available) + """ + video_bytes_io = io.BytesIO() + + # Use provided format/codec if specified, otherwise use video's own if available + format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4) + codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264) + + video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use) + video_bytes_io.seek(0) + return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8") + + def upload_video_to_comfyapi( video: VideoInput, - auth_kwargs: Optional[dict[str,str]] = None, + auth_kwargs: Optional[dict[str, str]] = None, container: VideoContainer = VideoContainer.MP4, codec: VideoCodec = VideoCodec.H264, max_duration: Optional[int] = None, @@ -454,7 +503,7 @@ def audio_ndarray_to_bytesio( def upload_audio_to_comfyapi( audio: AudioInput, - auth_kwargs: Optional[dict[str,str]] = None, + auth_kwargs: Optional[dict[str, str]] = None, container_format: str = "mp4", codec_name: str = "aac", mime_type: str = "audio/mp4", @@ -481,8 +530,25 @@ def upload_audio_to_comfyapi( return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs) +def audio_to_base64_string( + audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac" +) -> str: + """Converts an audio input to a base64 string.""" + sample_rate: int = audio["sample_rate"] + waveform: torch.Tensor = audio["waveform"] + audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) + audio_bytes_io = audio_ndarray_to_bytesio( + audio_data_np, sample_rate, container_format, codec_name + ) + audio_bytes = audio_bytes_io.getvalue() + return base64.b64encode(audio_bytes).decode("utf-8") + + def upload_images_to_comfyapi( - image: torch.Tensor, max_images=8, auth_kwargs: Optional[dict[str,str]] = None, mime_type: Optional[str] = None + image: torch.Tensor, + max_images=8, + auth_kwargs: Optional[dict[str, str]] = None, + mime_type: Optional[str] = None, ) -> list[str]: """ Uploads images to ComfyUI API and returns download URLs. @@ -547,17 +613,24 @@ def upload_images_to_comfyapi( return download_urls -def resize_mask_to_image(mask: torch.Tensor, image: torch.Tensor, - upscale_method="nearest-exact", crop="disabled", - allow_gradient=True, add_channel_dim=False): +def resize_mask_to_image( + mask: torch.Tensor, + image: torch.Tensor, + upscale_method="nearest-exact", + crop="disabled", + allow_gradient=True, + add_channel_dim=False, +): """ Resize mask to be the same dimensions as an image, while maintaining proper format for API calls. 
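+
+    The mask gains a trailing channel axis, is moved to NCHW for
+    common_upscale at the image's (H, W), then moved back; when
+    add_channel_dim is False the channel axis is squeezed away again.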
""" _, H, W, _ = image.shape mask = mask.unsqueeze(-1) - mask = mask.movedim(-1,1) - mask = common_upscale(mask, width=W, height=H, upscale_method=upscale_method, crop=crop) - mask = mask.movedim(1,-1) + mask = mask.movedim(-1, 1) + mask = common_upscale( + mask, width=W, height=H, upscale_method=upscale_method, crop=crop + ) + mask = mask.movedim(1, -1) if not add_channel_dim: mask = mask.squeeze(-1) if not allow_gradient: @@ -565,12 +638,41 @@ def resize_mask_to_image(mask: torch.Tensor, image: torch.Tensor, return mask -def validate_string(string: str, strip_whitespace=True, field_name="prompt", min_length=None, max_length=None): +def validate_string( + string: str, + strip_whitespace=True, + field_name="prompt", + min_length=None, + max_length=None, +): + if string is None: + raise Exception(f"Field '{field_name}' cannot be empty.") if strip_whitespace: string = string.strip() if min_length and len(string) < min_length: - raise Exception(f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long.") + raise Exception( + f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long." + ) if max_length and len(string) > max_length: - raise Exception(f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long.") - if not string: - raise Exception(f"Field '{field_name}' cannot be empty.") + raise Exception( + f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long." + ) + + +def image_tensor_pair_to_batch( + image1: torch.Tensor, image2: torch.Tensor +) -> torch.Tensor: + """ + Converts a pair of image tensors to a batch tensor. + If the images are not the same size, the smaller image is resized to + match the larger image. + """ + if image1.shape[1:] != image2.shape[1:]: + image2 = common_upscale( + image2.movedim(-1, 1), + image1.shape[2], + image1.shape[1], + "bilinear", + "center", + ).movedim(1, -1) + return torch.cat((image1, image2), dim=0) diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py index aa1c4ce0b..e38d38cc9 100644 --- a/comfy_api_nodes/apis/__init__.py +++ b/comfy_api_nodes/apis/__init__.py @@ -1,67 +1,197 @@ # generated by datamodel-codegen: # filename: filtered-openapi.yaml -# timestamp: 2025-05-04T04:12:39+00:00 +# timestamp: 2025-05-19T21:38:55+00:00 from __future__ import annotations -from datetime import datetime +from datetime import date, datetime from enum import Enum from typing import Any, Dict, List, Literal, Optional, Union from uuid import UUID -from pydantic import AnyUrl, BaseModel, Field, RootModel, StrictBytes +from pydantic import AnyUrl, BaseModel, ConfigDict, Field, RootModel, StrictBytes -class PersonalAccessToken(BaseModel): - id: Optional[UUID] = Field(None, description='Unique identifier for the GitCommit') - name: Optional[str] = Field( - None, - description='Required. The name of the token. Can be a simple description.', - ) - description: Optional[str] = Field( - None, - description="Optional. 
A more detailed description of the token's intended use.", +class APIKey(BaseModel): + created_at: Optional[datetime] = None + description: Optional[str] = None + id: Optional[str] = None + key_prefix: Optional[str] = None + name: Optional[str] = None + + +class APIKeyWithPlaintext(APIKey): + plaintext_key: Optional[str] = Field( + None, description='The full API key (only returned at creation)' ) + + +class AuditLog(BaseModel): createdAt: Optional[datetime] = Field( - None, description='[Output Only]The date and time the token was created.' + None, description='The date and time the event was created' ) - token: Optional[str] = Field( + event_id: Optional[str] = Field(None, description='the id of the event') + event_type: Optional[str] = Field(None, description='the type of the event') + params: Optional[Dict[str, Any]] = Field( + None, description='data related to the event' + ) + + +class OutputFormat(str, Enum): + jpeg = 'jpeg' + png = 'png' + + +class BFLFluxPro11GenerateRequest(BaseModel): + height: int = Field(..., description='Height of the generated image') + image_prompt: Optional[str] = Field(None, description='Optional image prompt') + output_format: Optional[OutputFormat] = Field( + None, description='Output image format' + ) + prompt: str = Field(..., description='The main text prompt for image generation') + prompt_upsampling: Optional[bool] = Field( + None, description='Whether to use prompt upsampling' + ) + safety_tolerance: Optional[int] = Field(None, description='Safety tolerance level') + seed: Optional[int] = Field(None, description='Random seed for reproducibility') + webhook_secret: Optional[str] = Field( + None, description='Optional webhook secret for async processing' + ) + webhook_url: Optional[str] = Field( + None, description='Optional webhook URL for async processing' + ) + width: int = Field(..., description='Width of the generated image') + + +class BFLFluxPro11GenerateResponse(BaseModel): + id: str = Field(..., description='Job ID for tracking') + polling_url: str = Field(..., description='URL to poll for results') + + +class BFLFluxProGenerateRequest(BaseModel): + guidance_scale: Optional[float] = Field( + None, description='The guidance scale for generation.', ge=1.0, le=20.0 + ) + height: int = Field( + ..., description='The height of the image to generate.', ge=64, le=2048 + ) + negative_prompt: Optional[str] = Field( + None, description='The negative prompt for image generation.' 
+ ) + num_images: Optional[int] = Field( + None, description='The number of images to generate.', ge=1, le=4 + ) + num_inference_steps: Optional[int] = Field( + None, description='The number of inference steps.', ge=1, le=100 + ) + prompt: str = Field(..., description='The text prompt for image generation.') + seed: Optional[int] = Field(None, description='The seed value for reproducibility.') + width: int = Field( + ..., description='The width of the image to generate.', ge=64, le=2048 + ) + + +class BFLFluxProGenerateResponse(BaseModel): + id: str = Field(..., description='The unique identifier for the generation task.') + polling_url: str = Field(..., description='URL to poll for the generation result.') + + +class Status(str, Enum): + in_progress = 'in_progress' + completed = 'completed' + incomplete = 'incomplete' + + +class Type(str, Enum): + computer_call = 'computer_call' + + +class ComputerToolCall(BaseModel): + action: Dict[str, Any] + call_id: str = Field( + ..., + description='An identifier used when responding to the tool call with output.\n', + ) + id: str = Field(..., description='The unique ID of the computer call.') + status: Status = Field( + ..., + description='The status of the item. One of `in_progress`, `completed`, or\n`incomplete`. Populated when items are returned via API.\n', + ) + type: Type = Field( + ..., description='The type of the computer call. Always `computer_call`.' + ) + + +class Environment(str, Enum): + windows = 'windows' + mac = 'mac' + linux = 'linux' + ubuntu = 'ubuntu' + browser = 'browser' + + +class Type1(str, Enum): + computer_use_preview = 'computer_use_preview' + + +class ComputerUsePreviewTool(BaseModel): + display_height: int = Field(..., description='The height of the computer display.') + display_width: int = Field(..., description='The width of the computer display.') + environment: Environment = Field( + ..., description='The type of computer environment to control.' + ) + type: Literal['ComputerUsePreviewTool'] = Field( + ..., + description='The type of the computer use tool. Always `computer_use_preview`.', + ) + + +class CreateAPIKeyRequest(BaseModel): + description: Optional[str] = None + name: str + + +class Customer(BaseModel): + createdAt: Optional[datetime] = Field( + None, description='The date and time the user was created' + ) + email: Optional[str] = Field(None, description='The email address for this user') + id: str = Field(..., description='The firebase UID of the user') + is_admin: Optional[bool] = Field(None, description='Whether the user is an admin') + metronome_id: Optional[str] = Field(None, description='The Metronome customer ID') + name: Optional[str] = Field(None, description='The name for this user') + stripe_id: Optional[str] = Field(None, description='The Stripe customer ID') + updatedAt: Optional[datetime] = Field( + None, description='The date and time the user was last updated' + ) + + +class CustomerStorageResourceResponse(BaseModel): + download_url: Optional[str] = Field( None, - description='[Output Only]. The personal access token. 
Only returned during creation.', + description='The signed URL to use for downloading the file from the specified path', + ) + existing_file: Optional[bool] = Field( + None, description='Whether an existing file with the same hash was found' + ) + expires_at: Optional[datetime] = Field( + None, description='When the signed URL will expire' + ) + upload_url: Optional[str] = Field( + None, + description='The signed URL to use for uploading the file to the specified path', ) -class GitCommitSummary(BaseModel): - commit_hash: Optional[str] = Field(None, description='The hash of the commit') - commit_name: Optional[str] = Field(None, description='The name of the commit') - branch_name: Optional[str] = Field( - None, description='The branch where the commit was made' - ) - author: Optional[str] = Field(None, description='The author of the commit') - timestamp: Optional[datetime] = Field( - None, description='The timestamp when the commit was made' - ) - status_summary: Optional[Dict[str, str]] = Field( - None, description='A map of operating system to status pairs' - ) +class Role(str, Enum): + user = 'user' + assistant = 'assistant' + system = 'system' + developer = 'developer' -class User(BaseModel): - id: Optional[str] = Field(None, description='The unique id for this user.') - email: Optional[str] = Field(None, description='The email address for this user.') - name: Optional[str] = Field(None, description='The name for this user.') - isApproved: Optional[bool] = Field( - None, description='Indicates if the user is approved.' - ) - isAdmin: Optional[bool] = Field( - None, description='Indicates if the user has admin privileges.' - ) - - -class PublisherUser(BaseModel): - id: Optional[str] = Field(None, description='The unique id for this user.') - email: Optional[str] = Field(None, description='The email address for this user.') - name: Optional[str] = Field(None, description='The name for this user.') +class Type2(str, Enum): + message = 'message' class ErrorResponse(BaseModel): @@ -69,168 +199,247 @@ class ErrorResponse(BaseModel): message: str -class StorageFile(BaseModel): - id: Optional[UUID] = Field( - None, description='Unique identifier for the storage file' - ) - file_path: Optional[str] = Field(None, description='Path to the file in storage') - public_url: Optional[str] = Field(None, description='Public URL') +class Type3(str, Enum): + file_search = 'file_search' -class PublisherMember(BaseModel): - id: Optional[str] = Field( - None, description='The unique identifier for the publisher member.' - ) - user: Optional[PublisherUser] = Field( - None, description='The user associated with this publisher member.' - ) - role: Optional[str] = Field( - None, description='The role of the user in the publisher.' 
+class FileSearchTool(BaseModel): + type: Literal['FileSearchTool'] = Field(..., description='The type of tool') + vector_store_ids: List[str] = Field( + ..., description='IDs of vector stores to search in' ) -class ComfyNode(BaseModel): - comfy_node_name: Optional[str] = Field( - None, description='Unique identifier for the node' +class Result(BaseModel): + file_id: Optional[str] = Field(None, description='The unique ID of the file.\n') + filename: Optional[str] = Field(None, description='The name of the file.\n') + score: Optional[float] = Field( + None, description='The relevance score of the file - a value between 0 and 1.\n' ) - category: Optional[str] = Field( - None, - description='UI category where the node is listed, used for grouping nodes.', + text: Optional[str] = Field( + None, description='The text that was retrieved from the file.\n' ) + + +class Status1(str, Enum): + in_progress = 'in_progress' + searching = 'searching' + completed = 'completed' + incomplete = 'incomplete' + failed = 'failed' + + +class Type4(str, Enum): + file_search_call = 'file_search_call' + + +class FileSearchToolCall(BaseModel): + id: str = Field(..., description='The unique ID of the file search tool call.\n') + queries: List[str] = Field( + ..., description='The queries used to search for files.\n' + ) + results: Optional[List[Result]] = Field( + None, description='The results of the file search tool call.\n' + ) + status: Status1 = Field( + ..., + description='The status of the file search tool call. One of `in_progress`, \n`searching`, `incomplete` or `failed`,\n', + ) + type: Type4 = Field( + ..., + description='The type of the file search tool call. Always `file_search_call`.\n', + ) + + +class Type5(str, Enum): + function = 'function' + + +class FunctionTool(BaseModel): description: Optional[str] = Field( - None, description="Brief description of the node's functionality or purpose." + None, description='Description of what the function does' ) - input_types: Optional[str] = Field(None, description='Defines input parameters') - deprecated: Optional[bool] = Field( + name: str = Field(..., description='Name of the function') + parameters: Dict[str, Any] = Field( + ..., description='JSON Schema object describing the function parameters' + ) + type: Literal['FunctionTool'] = Field(..., description='The type of tool') + + +class Status2(str, Enum): + in_progress = 'in_progress' + completed = 'completed' + incomplete = 'incomplete' + + +class Type6(str, Enum): + function_call = 'function_call' + + +class FunctionToolCall(BaseModel): + arguments: str = Field( + ..., description='A JSON string of the arguments to pass to the function.\n' + ) + call_id: str = Field( + ..., + description='The unique ID of the function tool call generated by the model.\n', + ) + id: Optional[str] = Field( + None, description='The unique ID of the function tool call.\n' + ) + name: str = Field(..., description='The name of the function to run.\n') + status: Optional[Status2] = Field( None, - description='Indicates if the node is deprecated. Deprecated nodes are hidden in the UI.', + description='The status of the item. One of `in_progress`, `completed`, or\n`incomplete`. Populated when items are returned via API.\n', ) - experimental: Optional[bool] = Field( + type: Type6 = Field( + ..., description='The type of the function tool call. 
Always `function_call`.\n' + ) + + +class GeminiCitation(BaseModel): + authors: Optional[List[str]] = None + endIndex: Optional[int] = None + license: Optional[str] = None + publicationDate: Optional[date] = None + startIndex: Optional[int] = None + title: Optional[str] = None + uri: Optional[str] = None + + +class GeminiCitationMetadata(BaseModel): + citations: Optional[List[GeminiCitation]] = None + + +class Role1(str, Enum): + user = 'user' + model = 'model' + + +class GeminiFunctionDeclaration(BaseModel): + description: Optional[str] = None + name: str + parameters: Dict[str, Any] = Field( + ..., description='JSON schema for the function parameters' + ) + + +class GeminiGenerationConfig(BaseModel): + maxOutputTokens: Optional[int] = Field( None, - description='Indicates if the node is experimental, subject to changes or removal.', + description='Maximum number of tokens that can be generated in the response. A token is approximately 4 characters. 100 tokens correspond to roughly 60-80 words.\n', + examples=[2048], + ge=16, + le=8192, ) - output_is_list: Optional[List[bool]] = Field( - None, description='Boolean values indicating if each output is a list.' - ) - return_names: Optional[str] = Field( - None, description='Names of the outputs for clarity in workflows.' - ) - return_types: Optional[str] = Field( - None, description='Specifies the types of outputs produced by the node.' - ) - function: Optional[str] = Field( - None, description='Name of the entry-point function to execute the node.' - ) - - -class ComfyNodeCloudBuildInfo(BaseModel): - project_id: Optional[str] = None - project_number: Optional[str] = None - location: Optional[str] = None - build_id: Optional[str] = None - - -class Error(BaseModel): - message: Optional[str] = Field( - None, description='A clear and concise description of the error.' - ) - details: Optional[List[str]] = Field( + seed: Optional[int] = Field( None, - description='Optional detailed information about the error or hints for resolving it.', + description="When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used. Available for the following models:, gemini-2.5-flash-preview-04-1, gemini-2.5-pro-preview-05-0, gemini-2.0-flash-lite-00, gemini-2.0-flash-001\n", + examples=[343940597], + ) + stopSequences: Optional[List[str]] = None + temperature: Optional[float] = Field( + 1, + description="The temperature is used for sampling during response generation, which occurs when topP and topK are applied. Temperature controls the degree of randomness in token selection. Lower temperatures are good for prompts that require a less open-ended or creative response, while higher temperatures can lead to more diverse or creative results. A temperature of 0 means that the highest probability tokens are always selected. In this case, responses for a given prompt are mostly deterministic, but a small amount of variation is still possible. If the model returns a response that's too generic, too short, or the model gives a fallback response, try increasing the temperature\n", + ge=0.0, + le=2.0, + ) + topK: Optional[int] = Field( + 40, + description="Top-K changes how the model selects tokens for output. 
A top-K of 1 means the next selected token is the most probable among all tokens in the model's vocabulary. A top-K of 3 means that the next token is selected from among the 3 most probable tokens by using temperature.\n", + examples=[40], + ge=1, + ) + topP: Optional[float] = Field( + 0.95, + description='If specified, nucleus sampling is used.\nTop-P changes how the model selects tokens for output. Tokens are selected from the most (see top-K) to least probable until the sum of their probabilities equals the top-P value. For example, if tokens A, B, and C have a probability of 0.3, 0.2, and 0.1 and the top-P value is 0.5, then the model will select either A or B as the next token by using temperature and excludes C as a candidate.\nSpecify a lower value for less random responses and a higher value for more random responses.\n', + ge=0.0, + le=1.0, ) -class NodeVersionUpdateRequest(BaseModel): - changelog: Optional[str] = Field( - None, description='The changelog describing the version changes.' +class GeminiMimeType(str, Enum): + application_pdf = 'application/pdf' + audio_mpeg = 'audio/mpeg' + audio_mp3 = 'audio/mp3' + audio_wav = 'audio/wav' + image_png = 'image/png' + image_jpeg = 'image/jpeg' + image_webp = 'image/webp' + text_plain = 'text/plain' + video_mov = 'video/mov' + video_mpeg = 'video/mpeg' + video_mp4 = 'video/mp4' + video_mpg = 'video/mpg' + video_avi = 'video/avi' + video_wmv = 'video/wmv' + video_mpegps = 'video/mpegps' + video_flv = 'video/flv' + + +class GeminiOffset(BaseModel): + nanos: Optional[int] = Field( + None, + description='Signed fractions of a second at nanosecond resolution. Negative second values with fractions must still have non-negative nanos values.\n', + examples=[0], + ge=0, + le=999999999, ) - deprecated: Optional[bool] = Field( - None, description='Whether the version is deprecated.' + seconds: Optional[int] = Field( + None, + description='Signed seconds of the span of time. Must be from -315,576,000,000 to +315,576,000,000 inclusive.\n', + examples=[60], + ge=-315576000000, + le=315576000000, ) -class NodeStatus(str, Enum): - NodeStatusActive = 'NodeStatusActive' - NodeStatusDeleted = 'NodeStatusDeleted' - NodeStatusBanned = 'NodeStatusBanned' +class GeminiSafetyCategory(str, Enum): + HARM_CATEGORY_SEXUALLY_EXPLICIT = 'HARM_CATEGORY_SEXUALLY_EXPLICIT' + HARM_CATEGORY_HATE_SPEECH = 'HARM_CATEGORY_HATE_SPEECH' + HARM_CATEGORY_HARASSMENT = 'HARM_CATEGORY_HARASSMENT' + HARM_CATEGORY_DANGEROUS_CONTENT = 'HARM_CATEGORY_DANGEROUS_CONTENT' -class NodeVersionStatus(str, Enum): - NodeVersionStatusActive = 'NodeVersionStatusActive' - NodeVersionStatusDeleted = 'NodeVersionStatusDeleted' - NodeVersionStatusBanned = 'NodeVersionStatusBanned' - NodeVersionStatusPending = 'NodeVersionStatusPending' - NodeVersionStatusFlagged = 'NodeVersionStatusFlagged' +class Probability(str, Enum): + NEGLIGIBLE = 'NEGLIGIBLE' + LOW = 'LOW' + MEDIUM = 'MEDIUM' + HIGH = 'HIGH' + UNKNOWN = 'UNKNOWN' -class PublisherStatus(str, Enum): - PublisherStatusActive = 'PublisherStatusActive' - PublisherStatusBanned = 'PublisherStatusBanned' - - -class WorkflowRunStatus(str, Enum): - WorkflowRunStatusStarted = 'WorkflowRunStatusStarted' - WorkflowRunStatusFailed = 'WorkflowRunStatusFailed' - WorkflowRunStatusCompleted = 'WorkflowRunStatusCompleted' - - -class MachineStats(BaseModel): - machine_name: Optional[str] = Field(None, description='Name of the machine.') - os_version: Optional[str] = Field( - None, description='The operating system version. eg. 
Ubuntu Linux 20.04' - ) - gpu_type: Optional[str] = Field( - None, description='The GPU type. eg. NVIDIA Tesla K80' - ) - cpu_capacity: Optional[str] = Field(None, description='Total CPU on the machine.') - initial_cpu: Optional[str] = Field( - None, description='Initial CPU available before the job starts.' - ) - memory_capacity: Optional[str] = Field( - None, description='Total memory on the machine.' - ) - initial_ram: Optional[str] = Field( - None, description='Initial RAM available before the job starts.' - ) - vram_time_series: Optional[Dict[str, Any]] = Field( - None, description='Time series of VRAM usage.' - ) - disk_capacity: Optional[str] = Field( - None, description='Total disk capacity on the machine.' - ) - initial_disk: Optional[str] = Field( - None, description='Initial disk available before the job starts.' - ) - pip_freeze: Optional[str] = Field(None, description='The pip freeze output') - - -class Customer(BaseModel): - id: str = Field(..., description='The firebase UID of the user') - email: Optional[str] = Field(None, description='The email address for this user') - name: Optional[str] = Field(None, description='The name for this user') - createdAt: Optional[datetime] = Field( - None, description='The date and time the user was created' - ) - updatedAt: Optional[datetime] = Field( - None, description='The date and time the user was last updated' +class GeminiSafetyRating(BaseModel): + category: Optional[GeminiSafetyCategory] = None + probability: Optional[Probability] = Field( + None, + description='The probability that the content violates the specified safety category', ) -class MagicPrompt(str, Enum): - ON = 'ON' +class GeminiSafetyThreshold(str, Enum): OFF = 'OFF' + BLOCK_NONE = 'BLOCK_NONE' + BLOCK_LOW_AND_ABOVE = 'BLOCK_LOW_AND_ABOVE' + BLOCK_MEDIUM_AND_ABOVE = 'BLOCK_MEDIUM_AND_ABOVE' + BLOCK_ONLY_HIGH = 'BLOCK_ONLY_HIGH' -class ColorPalette(BaseModel): - name: str = Field(..., description='Name of the color palette', examples=['PASTEL']) +class GeminiTextPart(BaseModel): + text: Optional[str] = Field( + None, + description='A text prompt or code snippet.', + examples=['Answer as concisely as possible'], + ) -class StyleCode(RootModel[str]): - root: str = Field(..., pattern='^[0-9A-Fa-f]{8}$') +class GeminiTool(BaseModel): + functionDeclarations: Optional[List[GeminiFunctionDeclaration]] = None -class StyleType(str, Enum): - GENERAL = 'GENERAL' +class GeminiVideoMetadata(BaseModel): + endOffset: Optional[GeminiOffset] = None + startOffset: Optional[GeminiOffset] = None class IdeogramColorPalette1(BaseModel): @@ -262,17 +471,34 @@ class IdeogramColorPalette( class ImageRequest(BaseModel): - prompt: str = Field( - ..., description='Required. The prompt to use to generate the image.' - ) aspect_ratio: Optional[str] = Field( None, description="Optional. The aspect ratio (e.g., 'ASPECT_16_9', 'ASPECT_1_1'). Cannot be used with resolution. Defaults to 'ASPECT_1_1' if unspecified.", ) - model: str = Field(..., description="The model used (e.g., 'V_2', 'V_2A_TURBO')") + color_palette: Optional[Dict[str, Any]] = Field( + None, description='Optional. Color palette object. Only for V_2, V_2_TURBO.' + ) magic_prompt_option: Optional[str] = Field( None, description="Optional. MagicPrompt usage ('AUTO', 'ON', 'OFF')." ) + model: str = Field(..., description="The model used (e.g., 'V_2', 'V_2A_TURBO')") + negative_prompt: Optional[str] = Field( + None, + description='Optional. Description of what to exclude. 
Only for V_1, V_1_TURBO, V_2, V_2_TURBO.', + ) + num_images: Optional[int] = Field( + 1, + description='Optional. Number of images to generate (1-8). Defaults to 1.', + ge=1, + le=8, + ) + prompt: str = Field( + ..., description='Required. The prompt to use to generate the image.' + ) + resolution: Optional[str] = Field( + None, + description="Optional. Resolution (e.g., 'RESOLUTION_1024_1024'). Only for model V_2. Cannot be used with aspect_ratio.", + ) seed: Optional[int] = Field( None, description='Optional. A number between 0 and 2147483647.', @@ -283,23 +509,6 @@ class ImageRequest(BaseModel): None, description="Optional. Style type ('AUTO', 'GENERAL', 'REALISTIC', 'DESIGN', 'RENDER_3D', 'ANIME'). Only for models V_2 and above.", ) - negative_prompt: Optional[str] = Field( - None, - description='Optional. Description of what to exclude. Only for V_1, V_1_TURBO, V_2, V_2_TURBO.', - ) - num_images: Optional[int] = Field( - 1, - description='Optional. Number of images to generate (1-8). Defaults to 1.', - ge=1, - le=8, - ) - resolution: Optional[str] = Field( - None, - description="Optional. Resolution (e.g., 'RESOLUTION_1024_1024'). Only for model V_2. Cannot be used with aspect_ratio.", - ) - color_palette: Optional[Dict[str, Any]] = Field( - None, description='Optional. Color palette object. Only for V_2, V_2_TURBO.' - ) class IdeogramGenerateRequest(BaseModel): @@ -309,23 +518,23 @@ class IdeogramGenerateRequest(BaseModel): class Datum(BaseModel): + is_image_safe: Optional[bool] = Field( + None, description='Indicates whether the image is considered safe.' + ) prompt: Optional[str] = Field( None, description='The prompt used to generate this image.' ) resolution: Optional[str] = Field( None, description="The resolution of the generated image (e.g., '1024x1024')." ) - is_image_safe: Optional[bool] = Field( - None, description='Indicates whether the image is considered safe.' - ) seed: Optional[int] = Field( None, description='The seed value used for this generation.' 
) - url: Optional[str] = Field(None, description='URL to the generated image.') style_type: Optional[str] = Field( None, description="The style type used for generation (e.g., 'REALISTIC', 'ANIME').", ) + url: Optional[str] = Field(None, description='URL to the generated image.') class IdeogramGenerateResponse(BaseModel): @@ -337,49 +546,17 @@ class IdeogramGenerateResponse(BaseModel): ) -class RenderingSpeed1(str, Enum): - TURBO = 'TURBO' - DEFAULT = 'DEFAULT' - QUALITY = 'QUALITY' - - -class MagicPrompt1(str, Enum): - AUTO = 'AUTO' - ON = 'ON' - OFF = 'OFF' - - -class StyleType1(str, Enum): - AUTO = 'AUTO' - GENERAL = 'GENERAL' - REALISTIC = 'REALISTIC' - DESIGN = 'DESIGN' - - -class IdeogramV3RemixRequest(BaseModel): - image: Optional[StrictBytes] = None - prompt: str - image_weight: Optional[int] = Field(50, ge=1, le=100) - seed: Optional[int] = Field(None, ge=0, le=2147483647) - resolution: Optional[str] = None - aspect_ratio: Optional[str] = None - rendering_speed: Optional[RenderingSpeed1] = None - magic_prompt: Optional[MagicPrompt1] = None - negative_prompt: Optional[str] = None - num_images: Optional[int] = Field(None, ge=1, le=8) - color_palette: Optional[Dict[str, Any]] = None - style_codes: Optional[List[str]] = None - style_type: Optional[StyleType1] = None - style_reference_images: Optional[List[StrictBytes]] = None +class StyleCode(RootModel[str]): + root: str = Field(..., pattern='^[0-9A-Fa-f]{8}$') class Datum1(BaseModel): + is_image_safe: Optional[bool] = None prompt: Optional[str] = None resolution: Optional[str] = None - is_image_safe: Optional[bool] = None seed: Optional[int] = None - url: Optional[str] = None style_type: Optional[str] = None + url: Optional[str] = None class IdeogramV3IdeogramResponse(BaseModel): @@ -387,74 +564,201 @@ class IdeogramV3IdeogramResponse(BaseModel): data: Optional[List[Datum1]] = None +class RenderingSpeed1(str, Enum): + TURBO = 'TURBO' + DEFAULT = 'DEFAULT' + QUALITY = 'QUALITY' + + class IdeogramV3ReframeRequest(BaseModel): - image: Optional[StrictBytes] = None - resolution: str - num_images: Optional[int] = Field(None, ge=1, le=8) - seed: Optional[int] = Field(None, ge=0, le=2147483647) - rendering_speed: Optional[RenderingSpeed1] = None color_palette: Optional[Dict[str, Any]] = None + image: Optional[StrictBytes] = None + num_images: Optional[int] = Field(None, ge=1, le=8) + rendering_speed: Optional[RenderingSpeed1] = None + resolution: str + seed: Optional[int] = Field(None, ge=0, le=2147483647) style_codes: Optional[List[str]] = None style_reference_images: Optional[List[StrictBytes]] = None +class MagicPrompt(str, Enum): + AUTO = 'AUTO' + ON = 'ON' + OFF = 'OFF' + + +class StyleType(str, Enum): + AUTO = 'AUTO' + GENERAL = 'GENERAL' + REALISTIC = 'REALISTIC' + DESIGN = 'DESIGN' + + +class IdeogramV3RemixRequest(BaseModel): + aspect_ratio: Optional[str] = None + color_palette: Optional[Dict[str, Any]] = None + image: Optional[StrictBytes] = None + image_weight: Optional[int] = Field(50, ge=1, le=100) + magic_prompt: Optional[MagicPrompt] = None + negative_prompt: Optional[str] = None + num_images: Optional[int] = Field(None, ge=1, le=8) + prompt: str + rendering_speed: Optional[RenderingSpeed1] = None + resolution: Optional[str] = None + seed: Optional[int] = Field(None, ge=0, le=2147483647) + style_codes: Optional[List[str]] = None + style_reference_images: Optional[List[StrictBytes]] = None + style_type: Optional[StyleType] = None + + class IdeogramV3ReplaceBackgroundRequest(BaseModel): - image: Optional[StrictBytes] = None - 
prompt: str - magic_prompt: Optional[MagicPrompt1] = None - num_images: Optional[int] = Field(None, ge=1, le=8) - seed: Optional[int] = Field(None, ge=0, le=2147483647) - rendering_speed: Optional[RenderingSpeed1] = None color_palette: Optional[Dict[str, Any]] = None + image: Optional[StrictBytes] = None + magic_prompt: Optional[MagicPrompt] = None + num_images: Optional[int] = Field(None, ge=1, le=8) + prompt: str + rendering_speed: Optional[RenderingSpeed1] = None + seed: Optional[int] = Field(None, ge=0, le=2147483647) style_codes: Optional[List[str]] = None style_reference_images: Optional[List[StrictBytes]] = None -class KlingTaskStatus(str, Enum): - submitted = 'submitted' - processing = 'processing' - succeed = 'succeed' - failed = 'failed' +class ColorPalette(BaseModel): + name: str = Field(..., description='Name of the color palette', examples=['PASTEL']) -class KlingVideoGenModelName(str, Enum): - kling_v1 = 'kling-v1' - kling_v1_5 = 'kling-v1-5' - kling_v1_6 = 'kling-v1-6' - kling_v2_master = 'kling-v2-master' +class MagicPrompt2(str, Enum): + ON = 'ON' + OFF = 'OFF' -class KlingVideoGenMode(str, Enum): - std = 'std' - pro = 'pro' +class StyleType1(str, Enum): + GENERAL = 'GENERAL' -class KlingVideoGenAspectRatio(str, Enum): - field_16_9 = '16:9' - field_9_16 = '9:16' +class ImagenImageGenerationInstance(BaseModel): + prompt: str = Field(..., description='Text prompt for image generation') + + +class AspectRatio(str, Enum): field_1_1 = '1:1' + field_9_16 = '9:16' + field_16_9 = '16:9' + field_3_4 = '3:4' + field_4_3 = '4:3' -class KlingVideoGenDuration(str, Enum): - field_5 = '5' - field_10 = '10' +class PersonGeneration(str, Enum): + dont_allow = 'dont_allow' + allow_adult = 'allow_adult' + allow_all = 'allow_all' -class KlingVideoGenCfgScale(RootModel[float]): - root: float = Field( - ..., - description="Flexibility in video generation. The higher the value, the lower the model's degree of flexibility, and the stronger the relevance to the user's prompt.", - ge=0.0, - le=1.0, +class SafetySetting(str, Enum): + block_most = 'block_most' + block_some = 'block_some' + block_few = 'block_few' + block_fewest = 'block_fewest' + + +class ImagenImagePrediction(BaseModel): + bytesBase64Encoded: Optional[str] = Field( + None, description='Base64-encoded image content' + ) + mimeType: Optional[str] = Field( + None, description='MIME type of the generated image' + ) + prompt: Optional[str] = Field( + None, description='Enhanced or rewritten prompt used to generate this image' ) -class KlingCameraControlType(str, Enum): - simple = 'simple' - down_back = 'down_back' - forward_up = 'forward_up' - right_turn_forward = 'right_turn_forward' - left_turn_forward = 'left_turn_forward' +class MimeType(str, Enum): + image_png = 'image/png' + image_jpeg = 'image/jpeg' + + +class ImagenOutputOptions(BaseModel): + compressionQuality: Optional[int] = Field(None, ge=0, le=100) + mimeType: Optional[MimeType] = None + + +class Includable(str, Enum): + file_search_call_results = 'file_search_call.results' + message_input_image_image_url = 'message.input_image.image_url' + computer_call_output_output_image_url = 'computer_call_output.output.image_url' + + +class Type7(str, Enum): + input_file = 'input_file' + + +class InputFileContent(BaseModel): + file_data: Optional[str] = Field( + None, description='The content of the file to be sent to the model.\n' + ) + file_id: Optional[str] = Field( + None, description='The ID of the file to be sent to the model.' 
+ ) + filename: Optional[str] = Field( + None, description='The name of the file to be sent to the model.' + ) + type: Type7 = Field( + ..., description='The type of the input item. Always `input_file`.' + ) + + +class Detail(str, Enum): + low = 'low' + high = 'high' + auto = 'auto' + + +class Type8(str, Enum): + input_image = 'input_image' + + +class InputImageContent(BaseModel): + detail: Detail = Field( + ..., + description='The detail level of the image to be sent to the model. One of `high`, `low`, or `auto`. Defaults to `auto`.', + ) + file_id: Optional[str] = Field( + None, description='The ID of the file to be sent to the model.' + ) + image_url: Optional[str] = Field( + None, + description='The URL of the image to be sent to the model. A fully qualified URL or base64 encoded image in a data URL.', + ) + type: Type8 = Field( + ..., description='The type of the input item. Always `input_image`.' + ) + + +class Role3(str, Enum): + user = 'user' + system = 'system' + developer = 'developer' + + +class Type9(str, Enum): + message = 'message' + + +class Type10(str, Enum): + input_text = 'input_text' + + +class InputTextContent(BaseModel): + text: str = Field(..., description='The text input to the model.') + type: Type10 = Field( + ..., description='The type of the input item. Always `input_text`.' + ) + + +class KlingAudioUploadType(str, Enum): + file = 'file' + url = 'url' class KlingCameraConfig(BaseModel): @@ -464,27 +768,27 @@ class KlingCameraConfig(BaseModel): ge=-10.0, le=10.0, ) - vertical: Optional[float] = Field( - None, - description="Controls camera's movement along vertical axis (y-axis). Negative indicates downward, positive indicates upward.", - ge=-10.0, - le=10.0, - ) pan: Optional[float] = Field( None, description="Controls camera's rotation in vertical plane (x-axis). Negative indicates downward rotation, positive indicates upward rotation.", ge=-10.0, le=10.0, ) + roll: Optional[float] = Field( + None, + description="Controls camera's rolling amount (z-axis). Negative indicates counterclockwise, positive indicates clockwise.", + ge=-10.0, + le=10.0, + ) tilt: Optional[float] = Field( None, description="Controls camera's rotation in horizontal plane (y-axis). Negative indicates left rotation, positive indicates right rotation.", ge=-10.0, le=10.0, ) - roll: Optional[float] = Field( + vertical: Optional[float] = Field( None, - description="Controls camera's rolling amount (z-axis). Negative indicates counterclockwise, positive indicates clockwise.", + description="Controls camera's movement along vertical axis (y-axis). 
Negative indicates downward, positive indicates upward.", ge=-10.0, le=10.0, ) @@ -496,39 +800,12 @@ class KlingCameraConfig(BaseModel): ) -class KlingVideoResult(BaseModel): - id: Optional[str] = Field(None, description='Generated video ID') - url: Optional[AnyUrl] = Field(None, description='URL for generated video') - duration: Optional[str] = Field(None, description='Total video duration') - - -class KlingAudioUploadType(str, Enum): - file = 'file' - url = 'url' - - -class KlingLipSyncMode(str, Enum): - text2video = 'text2video' - audio2video = 'audio2video' - - -class KlingLipSyncVoiceLanguage(str, Enum): - zh = 'zh' - en = 'en' - - -class KlingDualCharacterEffectsScene(str, Enum): - hug = 'hug' - kiss = 'kiss' - heart_gesture = 'heart_gesture' - - -class KlingSingleImageEffectsScene(str, Enum): - bloombloom = 'bloombloom' - dizzydizzy = 'dizzydizzy' - fuzzyfuzzy = 'fuzzyfuzzy' - squish = 'squish' - expansion = 'expansion' +class KlingCameraControlType(str, Enum): + simple = 'simple' + down_back = 'down_back' + forward_up = 'forward_up' + right_turn_forward = 'right_turn_forward' + left_turn_forward = 'left_turn_forward' class KlingCharacterEffectModelName(str, Enum): @@ -537,18 +814,50 @@ class KlingCharacterEffectModelName(str, Enum): kling_v1_6 = 'kling-v1-6' -class KlingSingleImageEffectModelName(str, Enum): - kling_v1_6 = 'kling-v1-6' - - -class KlingSingleImageEffectDuration(str, Enum): - field_5 = '5' +class KlingDualCharacterEffectsScene(str, Enum): + hug = 'hug' + kiss = 'kiss' + heart_gesture = 'heart_gesture' class KlingDualCharacterImages(RootModel[List[str]]): root: List[str] = Field(..., max_length=2, min_length=2) +class KlingErrorResponse(BaseModel): + code: int = Field( + ..., + description='- 1000: Authentication failed\n- 1001: Authorization is empty\n- 1002: Authorization is invalid\n- 1003: Authorization is not yet valid\n- 1004: Authorization has expired\n- 1100: Account exception\n- 1101: Account in arrears (postpaid scenario)\n- 1102: Resource pack depleted or expired (prepaid scenario)\n- 1103: Unauthorized access to requested resource\n- 1200: Invalid request parameters\n- 1201: Invalid parameters\n- 1202: Invalid request method\n- 1203: Requested resource does not exist\n- 1300: Trigger platform strategy\n- 1301: Trigger content security policy\n- 1302: API request too frequent\n- 1303: Concurrency/QPS exceeds limit\n- 1304: Trigger IP whitelist policy\n- 5000: Internal server error\n- 5001: Service temporarily unavailable\n- 5002: Server internal timeout\n', + ) + message: str = Field(..., description='Human-readable error message') + request_id: str = Field( + ..., description='Request ID for tracking and troubleshooting' + ) + + +class Trajectory(BaseModel): + x: Optional[int] = Field( + None, + description='The horizontal coordinate of trajectory point. Based on bottom-left corner of image as origin (0,0).', + ) + y: Optional[int] = Field( + None, + description='The vertical coordinate of trajectory point. Based on bottom-left corner of image as origin (0,0).', + ) + + +class DynamicMask(BaseModel): + mask: Optional[AnyUrl] = Field( + None, + description='Dynamic Brush Application Area (Mask image created by users using the motion brush). 
The aspect ratio must match the input image.', + ) + trajectories: Optional[List[Trajectory]] = None + + +class TaskInfo(BaseModel): + external_task_id: Optional[str] = None + + class KlingImageGenAspectRatio(str, Enum): field_16_9 = '16:9' field_9_16 = '9:16' @@ -571,278 +880,42 @@ class KlingImageGenModelName(str, Enum): kling_v2 = 'kling-v2' +class KlingImageGenerationsRequest(BaseModel): + aspect_ratio: Optional[KlingImageGenAspectRatio] = '16:9' + callback_url: Optional[AnyUrl] = Field( + None, description='The callback notification address' + ) + human_fidelity: Optional[float] = Field( + 0.45, description='Subject reference similarity', ge=0.0, le=1.0 + ) + image: Optional[str] = Field( + None, description='Reference Image - Base64 encoded string or image URL' + ) + image_fidelity: Optional[float] = Field( + 0.5, description='Reference intensity for user-uploaded images', ge=0.0, le=1.0 + ) + image_reference: Optional[KlingImageGenImageReferenceType] = None + model_name: Optional[KlingImageGenModelName] = 'kling-v1' + n: Optional[int] = Field(1, description='Number of generated images', ge=1, le=9) + negative_prompt: Optional[str] = Field( + None, description='Negative text prompt', max_length=200 + ) + prompt: str = Field(..., description='Positive text prompt', max_length=500) + + class KlingImageResult(BaseModel): index: Optional[int] = Field(None, description='Image Number (0-9)') url: Optional[AnyUrl] = Field(None, description='URL for generated image') -class KlingVirtualTryOnModelName(str, Enum): - kolors_virtual_try_on_v1 = 'kolors-virtual-try-on-v1' - kolors_virtual_try_on_v1_5 = 'kolors-virtual-try-on-v1-5' +class KlingLipSyncMode(str, Enum): + text2video = 'text2video' + audio2video = 'audio2video' -class TaskInfo(BaseModel): - external_task_id: Optional[str] = None - - -class TaskResult(BaseModel): - videos: Optional[List[KlingVideoResult]] = None - - -class Data(BaseModel): - task_id: Optional[str] = Field(None, description='Task ID') - task_status: Optional[KlingTaskStatus] = None - task_info: Optional[TaskInfo] = None - created_at: Optional[int] = Field(None, description='Task creation time') - updated_at: Optional[int] = Field(None, description='Task update time') - task_result: Optional[TaskResult] = None - - -class KlingText2VideoResponse(BaseModel): - code: Optional[int] = Field(None, description='Error code') - message: Optional[str] = Field(None, description='Error message') - request_id: Optional[str] = Field(None, description='Request ID') - data: Optional[Data] = None - - -class Trajectory(BaseModel): - x: Optional[int] = Field( - None, - description='The horizontal coordinate of trajectory point. Based on bottom-left corner of image as origin (0,0).', - ) - y: Optional[int] = Field( - None, - description='The vertical coordinate of trajectory point. Based on bottom-left corner of image as origin (0,0).', - ) - - -class DynamicMask(BaseModel): - mask: Optional[AnyUrl] = Field( - None, - description='Dynamic Brush Application Area (Mask image created by users using the motion brush). 
The aspect ratio must match the input image.', - ) - trajectories: Optional[List[Trajectory]] = None - - -class Data1(BaseModel): - task_id: Optional[str] = Field(None, description='Task ID') - task_status: Optional[KlingTaskStatus] = None - task_info: Optional[TaskInfo] = None - created_at: Optional[int] = Field(None, description='Task creation time') - updated_at: Optional[int] = Field(None, description='Task update time') - task_result: Optional[TaskResult] = None - - -class KlingImage2VideoResponse(BaseModel): - code: Optional[int] = Field(None, description='Error code') - message: Optional[str] = Field(None, description='Error message') - request_id: Optional[str] = Field(None, description='Request ID') - data: Optional[Data1] = None - - -class KlingVideoExtendRequest(BaseModel): - video_id: Optional[str] = Field( - None, - description='The ID of the video to be extended. Supports videos generated by text-to-video, image-to-video, and previous video extension operations. Cannot exceed 3 minutes total duration after extension.', - ) - prompt: Optional[str] = Field( - None, - description='Positive text prompt for guiding the video extension', - max_length=2500, - ) - negative_prompt: Optional[str] = Field( - None, - description='Negative text prompt for elements to avoid in the extended video', - max_length=2500, - ) - cfg_scale: Optional[KlingVideoGenCfgScale] = Field( - default_factory=lambda: KlingVideoGenCfgScale.model_validate(0.5) - ) - callback_url: Optional[AnyUrl] = Field( - None, - description='The callback notification address. Server will notify when the task status changes.', - ) - - -class Data2(BaseModel): - task_id: Optional[str] = Field(None, description='Task ID') - task_status: Optional[KlingTaskStatus] = None - task_info: Optional[TaskInfo] = None - created_at: Optional[int] = Field(None, description='Task creation time') - updated_at: Optional[int] = Field(None, description='Task update time') - task_result: Optional[TaskResult] = None - - -class KlingVideoExtendResponse(BaseModel): - code: Optional[int] = Field(None, description='Error code') - message: Optional[str] = Field(None, description='Error message') - request_id: Optional[str] = Field(None, description='Request ID') - data: Optional[Data2] = None - - -class KlingLipSyncInputObject(BaseModel): - video_id: Optional[str] = Field( - None, - description='The ID of the video generated by Kling AI. Only supports 5-second and 10-second videos generated within the last 30 days.', - ) - video_url: Optional[str] = Field( - None, - description='Get link for uploaded video. Video files support .mp4/.mov, file size does not exceed 100MB, video length between 2-10s.', - ) - mode: KlingLipSyncMode - text: Optional[str] = Field( - None, - description='Text Content for Lip-Sync Video Generation. Required when mode is text2video. Maximum length is 120 characters.', - ) - voice_id: Optional[str] = Field( - None, - description='Voice ID. Required when mode is text2video. The system offers a variety of voice options to choose from.', - ) - voice_language: Optional[KlingLipSyncVoiceLanguage] = 'en' - voice_speed: Optional[float] = Field( - 1, - description='Speech Rate. Valid range: 0.8~2.0, accurate to one decimal place.', - ge=0.8, - le=2.0, - ) - audio_type: Optional[KlingAudioUploadType] = None - audio_file: Optional[str] = Field( - None, - description='Local Path of Audio File. Supported formats: .mp3/.wav/.m4a/.aac, maximum file size of 5MB. 
Base64 code.', - ) - audio_url: Optional[str] = Field( - None, - description='Audio File Download URL. Supported formats: .mp3/.wav/.m4a/.aac, maximum file size of 5MB.', - ) - - -class KlingLipSyncRequest(BaseModel): - input: KlingLipSyncInputObject - callback_url: Optional[AnyUrl] = Field( - None, - description='The callback notification address. Server will notify when the task status changes.', - ) - - -class Data3(BaseModel): - task_id: Optional[str] = Field(None, description='Task ID') - task_status: Optional[KlingTaskStatus] = None - task_info: Optional[TaskInfo] = None - created_at: Optional[int] = Field(None, description='Task creation time') - updated_at: Optional[int] = Field(None, description='Task update time') - task_result: Optional[TaskResult] = None - - -class KlingLipSyncResponse(BaseModel): - code: Optional[int] = Field(None, description='Error code') - message: Optional[str] = Field(None, description='Error message') - request_id: Optional[str] = Field(None, description='Request ID') - data: Optional[Data3] = None - - -class KlingSingleImageEffectInput(BaseModel): - model_name: KlingSingleImageEffectModelName - image: str = Field( - ..., - description='Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1.', - ) - duration: KlingSingleImageEffectDuration - - -class KlingDualCharacterEffectInput(BaseModel): - model_name: Optional[KlingCharacterEffectModelName] = 'kling-v1' - mode: Optional[KlingVideoGenMode] = 'std' - images: KlingDualCharacterImages - duration: KlingVideoGenDuration - - -class Data4(BaseModel): - task_id: Optional[str] = Field(None, description='Task ID') - task_status: Optional[KlingTaskStatus] = None - task_info: Optional[TaskInfo] = None - created_at: Optional[int] = Field(None, description='Task creation time') - updated_at: Optional[int] = Field(None, description='Task update time') - task_result: Optional[TaskResult] = None - - -class KlingVideoEffectsResponse(BaseModel): - code: Optional[int] = Field(None, description='Error code') - message: Optional[str] = Field(None, description='Error message') - request_id: Optional[str] = Field(None, description='Request ID') - data: Optional[Data4] = None - - -class KlingImageGenerationsRequest(BaseModel): - model_name: Optional[KlingImageGenModelName] = 'kling-v1' - prompt: str = Field(..., description='Positive text prompt', max_length=500) - negative_prompt: Optional[str] = Field( - None, description='Negative text prompt', max_length=200 - ) - image: Optional[str] = Field( - None, description='Reference Image - Base64 encoded string or image URL' - ) - image_reference: Optional[KlingImageGenImageReferenceType] = None - image_fidelity: Optional[float] = Field( - 0.5, description='Reference intensity for user-uploaded images', ge=0.0, le=1.0 - ) - human_fidelity: Optional[float] = Field( - 0.45, description='Subject reference similarity', ge=0.0, le=1.0 - ) - n: Optional[int] = Field(1, description='Number of generated images', ge=1, le=9) - aspect_ratio: Optional[KlingImageGenAspectRatio] = '16:9' - callback_url: Optional[AnyUrl] = Field( - None, description='The callback notification address' - ) - - -class TaskResult5(BaseModel): - images: Optional[List[KlingImageResult]] = None - - -class Data5(BaseModel): - task_id: Optional[str] = Field(None, description='Task ID') - task_status: Optional[KlingTaskStatus] = None - task_status_msg: Optional[str] = Field(None, description='Task 
status information') - created_at: Optional[int] = Field(None, description='Task creation time') - updated_at: Optional[int] = Field(None, description='Task update time') - task_result: Optional[TaskResult5] = None - - -class KlingImageGenerationsResponse(BaseModel): - code: Optional[int] = Field(None, description='Error code') - message: Optional[str] = Field(None, description='Error message') - request_id: Optional[str] = Field(None, description='Request ID') - data: Optional[Data5] = None - - -class KlingVirtualTryOnRequest(BaseModel): - model_name: Optional[KlingVirtualTryOnModelName] = 'kolors-virtual-try-on-v1' - human_image: str = Field( - ..., description='Reference human image - Base64 encoded string or image URL' - ) - cloth_image: Optional[str] = Field( - None, - description='Reference clothing image - Base64 encoded string or image URL', - ) - callback_url: Optional[AnyUrl] = Field( - None, description='The callback notification address' - ) - - -class Data6(BaseModel): - task_id: Optional[str] = Field(None, description='Task ID') - task_status: Optional[KlingTaskStatus] = None - task_status_msg: Optional[str] = Field(None, description='Task status information') - created_at: Optional[int] = Field(None, description='Task creation time') - updated_at: Optional[int] = Field(None, description='Task update time') - task_result: Optional[TaskResult5] = None - - -class KlingVirtualTryOnResponse(BaseModel): - code: Optional[int] = Field(None, description='Error code') - message: Optional[str] = Field(None, description='Error message') - request_id: Optional[str] = Field(None, description='Request ID') - data: Optional[Data6] = None +class KlingLipSyncVoiceLanguage(str, Enum): + zh = 'zh' + en = 'en' class ResourcePackType(str, Enum): @@ -850,7 +923,7 @@ class ResourcePackType(str, Enum): constant_period = 'constant_period' -class Status(str, Enum): +class Status4(str, Enum): toBeOnline = 'toBeOnline' online = 'online' expired = 'expired' @@ -858,29 +931,29 @@ class Status(str, Enum): class ResourcePackSubscribeInfo(BaseModel): - resource_pack_name: Optional[str] = Field(None, description='Resource package name') - resource_pack_id: Optional[str] = Field(None, description='Resource package ID') - resource_pack_type: Optional[ResourcePackType] = Field( - None, - description='Resource package type (decreasing_total=decreasing total, constant_period=constant periodicity)', - ) - total_quantity: Optional[float] = Field(None, description='Total quantity') - remaining_quantity: Optional[float] = Field( - None, description='Remaining quantity (updated with a 12-hour delay)' - ) - purchase_time: Optional[int] = Field( - None, description='Purchase time, Unix timestamp in ms' - ) effective_time: Optional[int] = Field( None, description='Effective time, Unix timestamp in ms' ) invalid_time: Optional[int] = Field( None, description='Expiration time, Unix timestamp in ms' ) - status: Optional[Status] = Field(None, description='Resource Package Status') + purchase_time: Optional[int] = Field( + None, description='Purchase time, Unix timestamp in ms' + ) + remaining_quantity: Optional[float] = Field( + None, description='Remaining quantity (updated with a 12-hour delay)' + ) + resource_pack_id: Optional[str] = Field(None, description='Resource package ID') + resource_pack_name: Optional[str] = Field(None, description='Resource package name') + resource_pack_type: Optional[ResourcePackType] = Field( + None, + description='Resource package type (decreasing_total=decreasing total, 
constant_period=constant periodicity)', + ) + status: Optional[Status4] = Field(None, description='Resource Package Status') + total_quantity: Optional[float] = Field(None, description='Total quantity') -class Data7(BaseModel): +class Data3(BaseModel): code: Optional[int] = Field(None, description='Error code; 0 indicates success') msg: Optional[str] = Field(None, description='Error information') resource_pack_subscribe_infos: Optional[List[ResourcePackSubscribeInfo]] = Field( @@ -890,137 +963,313 @@ class Data7(BaseModel): class KlingResourcePackageResponse(BaseModel): code: Optional[int] = Field(None, description='Error code; 0 indicates success') + data: Optional[Data3] = None message: Optional[str] = Field(None, description='Error information') request_id: Optional[str] = Field( None, description='Request ID, generated by the system, used to track requests and troubleshoot problems', ) + + +class KlingSingleImageEffectDuration(str, Enum): + field_5 = '5' + + +class KlingSingleImageEffectModelName(str, Enum): + kling_v1_6 = 'kling-v1-6' + + +class KlingSingleImageEffectsScene(str, Enum): + bloombloom = 'bloombloom' + dizzydizzy = 'dizzydizzy' + fuzzyfuzzy = 'fuzzyfuzzy' + squish = 'squish' + expansion = 'expansion' + + +class KlingTaskStatus(str, Enum): + submitted = 'submitted' + processing = 'processing' + succeed = 'succeed' + failed = 'failed' + + +class KlingTextToVideoModelName(str, Enum): + kling_v1 = 'kling-v1' + kling_v1_6 = 'kling-v1-6' + + +class KlingVideoGenAspectRatio(str, Enum): + field_16_9 = '16:9' + field_9_16 = '9:16' + field_1_1 = '1:1' + + +class KlingVideoGenCfgScale(RootModel[float]): + root: float = Field( + ..., + description="Flexibility in video generation. The higher the value, the lower the model's degree of flexibility, and the stronger the relevance to the user's prompt.", + ge=0.0, + le=1.0, + ) + + +class KlingVideoGenDuration(str, Enum): + field_5 = '5' + field_10 = '10' + + +class KlingVideoGenMode(str, Enum): + std = 'std' + pro = 'pro' + + +class KlingVideoGenModelName(str, Enum): + kling_v1 = 'kling-v1' + kling_v1_5 = 'kling-v1-5' + kling_v1_6 = 'kling-v1-6' + kling_v2_master = 'kling-v2-master' + + +class KlingVideoResult(BaseModel): + duration: Optional[str] = Field(None, description='Total video duration') + id: Optional[str] = Field(None, description='Generated video ID') + url: Optional[AnyUrl] = Field(None, description='URL for generated video') + + +class KlingVirtualTryOnModelName(str, Enum): + kolors_virtual_try_on_v1 = 'kolors-virtual-try-on-v1' + kolors_virtual_try_on_v1_5 = 'kolors-virtual-try-on-v1-5' + + +class KlingVirtualTryOnRequest(BaseModel): + callback_url: Optional[AnyUrl] = Field( + None, description='The callback notification address' + ) + cloth_image: Optional[str] = Field( + None, + description='Reference clothing image - Base64 encoded string or image URL', + ) + human_image: str = Field( + ..., description='Reference human image - Base64 encoded string or image URL' + ) + model_name: Optional[KlingVirtualTryOnModelName] = 'kolors-virtual-try-on-v1' + + +class TaskResult6(BaseModel): + images: Optional[List[KlingImageResult]] = None + + +class Data7(BaseModel): + created_at: Optional[int] = Field(None, description='Task creation time') + task_id: Optional[str] = Field(None, description='Task ID') + task_result: Optional[TaskResult6] = None + task_status: Optional[KlingTaskStatus] = None + task_status_msg: Optional[str] = Field(None, description='Task status information') + updated_at: Optional[int] = Field(None, 
description='Task update time') + + +class KlingVirtualTryOnResponse(BaseModel): + code: Optional[int] = Field(None, description='Error code') data: Optional[Data7] = None + message: Optional[str] = Field(None, description='Error message') + request_id: Optional[str] = Field(None, description='Request ID') -class Object(str, Enum): - event = 'event' +class LumaAspectRatio(str, Enum): + field_1_1 = '1:1' + field_16_9 = '16:9' + field_9_16 = '9:16' + field_4_3 = '4:3' + field_3_4 = '3:4' + field_21_9 = '21:9' + field_9_21 = '9:21' -class Type(str, Enum): - payment_intent_succeeded = 'payment_intent.succeeded' +class LumaAssets(BaseModel): + image: Optional[AnyUrl] = Field(None, description='The URL of the image') + progress_video: Optional[AnyUrl] = Field( + None, description='The URL of the progress video' + ) + video: Optional[AnyUrl] = Field(None, description='The URL of the video') -class StripeRequestInfo(BaseModel): - id: Optional[str] = None - idempotency_key: Optional[str] = None +class GenerationType(str, Enum): + add_audio = 'add_audio' -class Object1(str, Enum): - payment_intent = 'payment_intent' +class LumaAudioGenerationRequest(BaseModel): + callback_url: Optional[AnyUrl] = Field( + None, description='The callback URL for the audio' + ) + generation_type: Optional[GenerationType] = 'add_audio' + negative_prompt: Optional[str] = Field( + None, description='The negative prompt of the audio' + ) + prompt: Optional[str] = Field(None, description='The prompt of the audio') -class StripeAmountDetails(BaseModel): - tip: Optional[Dict[str, Any]] = None +class LumaError(BaseModel): + detail: Optional[str] = Field(None, description='The error message') -class Object2(str, Enum): - charge = 'charge' +class Type11(str, Enum): + generation = 'generation' -class StripeAddress(BaseModel): - city: Optional[str] = None - country: Optional[str] = None - line1: Optional[str] = None - line2: Optional[str] = None - postal_code: Optional[str] = None - state: Optional[str] = None +class LumaGenerationReference(BaseModel): + id: UUID = Field(..., description='The ID of the generation') + type: Literal['generation'] -class StripeOutcome(BaseModel): - advice_code: Optional[Any] = None - network_advice_code: Optional[Any] = None - network_decline_code: Optional[Any] = None - network_status: Optional[str] = None - reason: Optional[Any] = None - risk_level: Optional[str] = None - risk_score: Optional[int] = None - seller_message: Optional[str] = None - type: Optional[str] = None +class GenerationType1(str, Enum): + video = 'video' -class Checks(BaseModel): - address_line1_check: Optional[Any] = None - address_postal_code_check: Optional[Any] = None - cvc_check: Optional[str] = None +class LumaGenerationType(str, Enum): + video = 'video' + image = 'image' -class ExtendedAuthorization(BaseModel): - status: Optional[str] = None +class GenerationType2(str, Enum): + image = 'image' -class IncrementalAuthorization(BaseModel): - status: Optional[str] = None +class LumaImageIdentity(BaseModel): + images: Optional[List[AnyUrl]] = Field( + None, description='The URLs of the image identity' + ) -class Multicapture(BaseModel): - status: Optional[str] = None +class LumaImageModel(str, Enum): + photon_1 = 'photon-1' + photon_flash_1 = 'photon-flash-1' -class NetworkToken(BaseModel): - used: Optional[bool] = None +class LumaImageRef(BaseModel): + url: Optional[AnyUrl] = Field(None, description='The URL of the image reference') + weight: Optional[float] = Field( + None, description='The weight of the image reference' + ) 
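
# A minimal usage sketch, assuming pydantic v2 and the models defined above:
# the regenerated classes carry their OpenAPI bounds as Field constraints, so
# out-of-range values raise ValidationError at construction/validation time.
# The URL and numbers below are placeholders, not values from the spec.
#
#     from pydantic import ValidationError
#
#     ref = LumaImageRef(url='https://example.com/ref.png', weight=0.85)
#     assert ref.weight == 0.85  # weight is unconstrained; url must parse as a URL
#
#     try:
#         KlingVideoGenCfgScale.model_validate(1.5)  # root float is bounded 0.0..1.0
#     except ValidationError as exc:
#         print(exc.errors()[0]['type'])  # 'less_than_equal'
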
-class Overcapture(BaseModel): - maximum_amount_capturable: Optional[int] = None - status: Optional[str] = None +class Type12(str, Enum): + image = 'image' -class StripeCardDetails(BaseModel): - amount_authorized: Optional[int] = None - authorization_code: Optional[Any] = None - brand: Optional[str] = None - checks: Optional[Checks] = None - country: Optional[str] = None - exp_month: Optional[int] = None - exp_year: Optional[int] = None - extended_authorization: Optional[ExtendedAuthorization] = None - fingerprint: Optional[str] = None - funding: Optional[str] = None - incremental_authorization: Optional[IncrementalAuthorization] = None - installments: Optional[Any] = None - last4: Optional[str] = None - mandate: Optional[Any] = None - multicapture: Optional[Multicapture] = None - network: Optional[str] = None - network_token: Optional[NetworkToken] = None - network_transaction_id: Optional[str] = None - overcapture: Optional[Overcapture] = None - regulated_status: Optional[str] = None - three_d_secure: Optional[Any] = None - wallet: Optional[Any] = None +class LumaImageReference(BaseModel): + type: Literal['image'] + url: AnyUrl = Field(..., description='The URL of the image') -class StripeRefundList(BaseModel): - object: Optional[str] = None - data: Optional[List[Dict[str, Any]]] = None - has_more: Optional[bool] = None - total_count: Optional[int] = None - url: Optional[str] = None +class LumaKeyframe(RootModel[Union[LumaGenerationReference, LumaImageReference]]): + root: Union[LumaGenerationReference, LumaImageReference] = Field( + ..., + description='A keyframe can be either a Generation reference, an Image, or a Video', + discriminator='type', + ) -class Card(BaseModel): - installments: Optional[Any] = None - mandate_options: Optional[Any] = None - network: Optional[Any] = None - request_three_d_secure: Optional[str] = None +class LumaKeyframes(BaseModel): + frame0: Optional[LumaKeyframe] = None + frame1: Optional[LumaKeyframe] = None -class StripePaymentMethodOptions(BaseModel): - card: Optional[Card] = None +class LumaModifyImageRef(BaseModel): + url: Optional[AnyUrl] = Field(None, description='The URL of the image reference') + weight: Optional[float] = Field( + None, description='The weight of the modify image reference' + ) -class StripeShipping(BaseModel): - address: Optional[StripeAddress] = None - carrier: Optional[str] = None - name: Optional[str] = None - phone: Optional[str] = None - tracking_number: Optional[str] = None +class LumaState(str, Enum): + queued = 'queued' + dreaming = 'dreaming' + completed = 'completed' + failed = 'failed' + + +class GenerationType3(str, Enum): + upscale_video = 'upscale_video' + + +class LumaVideoModel(str, Enum): + ray_2 = 'ray-2' + ray_flash_2 = 'ray-flash-2' + ray_1_6 = 'ray-1-6' + + +class LumaVideoModelOutputDuration1(str, Enum): + field_5s = '5s' + field_9s = '9s' + + +class LumaVideoModelOutputDuration( + RootModel[Union[LumaVideoModelOutputDuration1, str]] +): + root: Union[LumaVideoModelOutputDuration1, str] + + +class LumaVideoModelOutputResolution1(str, Enum): + field_540p = '540p' + field_720p = '720p' + field_1080p = '1080p' + field_4k = '4k' + + +class LumaVideoModelOutputResolution( + RootModel[Union[LumaVideoModelOutputResolution1, str]] +): + root: Union[LumaVideoModelOutputResolution1, str] + + +class MinimaxBaseResponse(BaseModel): + status_code: int = Field( + ..., + description='Status code. 
0 indicates success, other values indicate errors.', + ) + status_msg: str = Field( + ..., description='Specific error details or success message.' + ) + + +class File(BaseModel): + bytes: Optional[int] = Field(None, description='File size in bytes') + created_at: Optional[int] = Field( + None, description='Unix timestamp when the file was created, in seconds' + ) + download_url: Optional[str] = Field( + None, description='The URL to download the video' + ) + file_id: Optional[int] = Field(None, description='Unique identifier for the file') + filename: Optional[str] = Field(None, description='The name of the file') + purpose: Optional[str] = Field(None, description='The purpose of using the file') + + +class MinimaxFileRetrieveResponse(BaseModel): + base_resp: MinimaxBaseResponse + file: File + + +class Status5(str, Enum): + Queueing = 'Queueing' + Preparing = 'Preparing' + Processing = 'Processing' + Success = 'Success' + Fail = 'Fail' + + +class MinimaxTaskResultResponse(BaseModel): + base_resp: MinimaxBaseResponse + file_id: Optional[str] = Field( + None, + description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.', + ) + status: Status5 = Field( + ..., + description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 'Fail' (task failed).", + ) + task_id: str = Field(..., description='The task ID being queried.') class Model(str, Enum): @@ -1043,6 +1292,14 @@ class SubjectReferenceItem(BaseModel): class MinimaxVideoGenerationRequest(BaseModel): + callback_url: Optional[str] = Field( + None, + description='Optional. URL to receive real-time status updates about the video generation task.', + ) + first_frame_image: Optional[str] = Field( + None, + description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.', + ) model: Model = Field( ..., description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01', @@ -1056,927 +1313,175 @@ class MinimaxVideoGenerationRequest(BaseModel): True, description='If true (default), the model will automatically optimize the prompt. Set to false for more precise control.', ) - first_frame_image: Optional[str] = Field( - None, - description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.', - ) subject_reference: Optional[List[SubjectReferenceItem]] = Field( None, description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.', ) - callback_url: Optional[str] = Field( - None, - description='Optional. URL to receive real-time status updates about the video generation task.', - ) - - -class MinimaxBaseResponse(BaseModel): - status_code: int = Field( - ..., - description='Status code. 0 indicates success, other values indicate errors.', - ) - status_msg: str = Field( - ..., description='Specific error details or success message.' - ) class MinimaxVideoGenerationResponse(BaseModel): + base_resp: MinimaxBaseResponse task_id: str = Field( ..., description='The task ID for the asynchronous video generation task.' 
) - base_resp: MinimaxBaseResponse -class File(BaseModel): - file_id: Optional[int] = Field(None, description='Unique identifier for the file') - bytes: Optional[int] = Field(None, description='File size in bytes') - created_at: Optional[int] = Field( - None, description='Unix timestamp when the file was created, in seconds' +class Truncation(str, Enum): + disabled = 'disabled' + auto = 'auto' + + +class ModelResponseProperties(BaseModel): + instructions: Optional[str] = Field( + None, description='Instructions for the model on how to generate the response' ) - filename: Optional[str] = Field(None, description='The name of the file') - purpose: Optional[str] = Field(None, description='The purpose of using the file') - download_url: Optional[str] = Field( - None, description='The URL to download the video' + max_output_tokens: Optional[int] = Field( + None, description='Maximum number of tokens to generate' + ) + model: Optional[str] = Field( + None, description='The model used to generate the response' + ) + temperature: Optional[float] = Field( + 1, description='Controls randomness in the response', ge=0.0, le=2.0 + ) + top_p: Optional[float] = Field( + 1, + description='Controls diversity of the response via nucleus sampling', + ge=0.0, + le=1.0, + ) + truncation: Optional[Truncation] = Field( + 'disabled', description='How to handle truncation of the response' ) -class MinimaxFileRetrieveResponse(BaseModel): - file: File - base_resp: MinimaxBaseResponse +class Moderation(str, Enum): + low = 'low' + auto = 'auto' -class Status1(str, Enum): - Queueing = 'Queueing' - Preparing = 'Preparing' - Processing = 'Processing' - Success = 'Success' - Fail = 'Fail' - - -class MinimaxTaskResultResponse(BaseModel): - task_id: str = Field(..., description='The task ID being queried.') - status: Status1 = Field( - ..., - description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 'Fail' (task failed).", - ) - file_id: Optional[str] = Field( - None, - description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.', - ) - base_resp: MinimaxBaseResponse - - -class OutputFormat(str, Enum): - jpeg = 'jpeg' +class OutputFormat1(str, Enum): png = 'png' - - -class BFLFluxPro11GenerateRequest(BaseModel): - prompt: str = Field(..., description='The main text prompt for image generation') - image_prompt: Optional[str] = Field(None, description='Optional image prompt') - width: int = Field(..., description='Width of the generated image') - height: int = Field(..., description='Height of the generated image') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to use prompt upsampling' - ) - seed: Optional[int] = Field(None, description='Random seed for reproducibility') - safety_tolerance: Optional[int] = Field(None, description='Safety tolerance level') - output_format: Optional[OutputFormat] = Field( - None, description='Output image format' - ) - webhook_url: Optional[str] = Field( - None, description='Optional webhook URL for async processing' - ) - webhook_secret: Optional[str] = Field( - None, description='Optional webhook secret for async processing' - ) - - -class BFLFluxPro11GenerateResponse(BaseModel): - id: str = Field(..., description='Job ID for tracking') - polling_url: str = Field(..., description='URL to poll for results') - - -class BFLFluxProGenerateRequest(BaseModel): - prompt: str = Field(..., description='The text prompt 
for image generation.') - negative_prompt: Optional[str] = Field( - None, description='The negative prompt for image generation.' - ) - width: int = Field( - ..., description='The width of the image to generate.', ge=64, le=2048 - ) - height: int = Field( - ..., description='The height of the image to generate.', ge=64, le=2048 - ) - num_inference_steps: Optional[int] = Field( - None, description='The number of inference steps.', ge=1, le=100 - ) - guidance_scale: Optional[float] = Field( - None, description='The guidance scale for generation.', ge=1.0, le=20.0 - ) - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - num_images: Optional[int] = Field( - None, description='The number of images to generate.', ge=1, le=4 - ) - - -class BFLFluxProGenerateResponse(BaseModel): - id: str = Field(..., description='The unique identifier for the generation task.') - polling_url: str = Field(..., description='URL to poll for the generation result.') - - -class Steps(RootModel[int]): - root: int = Field( - ..., - description='Number of steps for the image generation process', - examples=[50], - ge=15, - le=50, - title='Steps', - ) - - -class Guidance(RootModel[float]): - root: float = Field( - ..., - description='Guidance strength for the image generation process', - ge=1.5, - le=100.0, - title='Guidance', - ) - - -class WebhookUrl(RootModel[AnyUrl]): - root: AnyUrl = Field( - ..., description='URL to receive webhook notifications', title='Webhook Url' - ) - - -class BFLAsyncResponse(BaseModel): - id: str = Field(..., title='Id') - polling_url: str = Field(..., title='Polling Url') - - -class BFLAsyncWebhookResponse(BaseModel): - id: str = Field(..., title='Id') - status: str = Field(..., title='Status') - webhook_url: str = Field(..., title='Webhook Url') - - -class Top(RootModel[int]): - root: int = Field( - ..., - description='Number of pixels to expand at the top of the image', - ge=0, - le=2048, - title='Top', - ) - - -class Bottom(RootModel[int]): - root: int = Field( - ..., - description='Number of pixels to expand at the bottom of the image', - ge=0, - le=2048, - title='Bottom', - ) - - -class Left(RootModel[int]): - root: int = Field( - ..., - description='Number of pixels to expand on the left side of the image', - ge=0, - le=2048, - title='Left', - ) - - -class Right(RootModel[int]): - root: int = Field( - ..., - description='Number of pixels to expand on the right side of the image', - ge=0, - le=2048, - title='Right', - ) - - -class CannyLowThreshold(RootModel[int]): - root: int = Field( - ..., - description='Low threshold for Canny edge detection', - ge=0, - le=500, - title='Canny Low Threshold', - ) - - -class CannyHighThreshold(RootModel[int]): - root: int = Field( - ..., - description='High threshold for Canny edge detection', - ge=0, - le=500, - title='Canny High Threshold', - ) - - -class Steps2(RootModel[int]): - root: int = Field( - ..., - description='Number of steps for the image generation process', - ge=15, - le=50, - title='Steps', - ) - - -class Guidance2(RootModel[float]): - root: float = Field( - ..., - description='Guidance strength for the image generation process', - ge=1.0, - le=100.0, - title='Guidance', - ) - - -class BFLOutputFormat(str, Enum): - jpeg = 'jpeg' - png = 'png' - - -class BFLValidationError(BaseModel): - loc: List[Union[str, int]] = Field(..., title='Location') - msg: str = Field(..., title='Message') - type: str = Field(..., title='Error Type') - - -class Datum2(BaseModel): - image_id: Optional[str] = Field( - None, 
description='Unique identifier for the generated image' - ) - url: Optional[str] = Field(None, description='URL to access the generated image') - - -class RecraftImageGenerationResponse(BaseModel): - created: int = Field( - ..., description='Unix timestamp when the generation was created' - ) - credits: int = Field(..., description='Number of credits used for the generation') - data: List[Datum2] = Field(..., description='Array of generated image information') - - -class RecraftImageFeatures(BaseModel): - nsfw_score: Optional[float] = None - - -class RecraftTextLayoutItem(BaseModel): - bbox: List[List[float]] - text: str - - -class RecraftImageColor(BaseModel): - rgb: Optional[List[int]] = None - std: Optional[List[float]] = None - weight: Optional[float] = None - - -class RecraftImageStyle(str, Enum): - digital_illustration = 'digital_illustration' - icon = 'icon' - realistic_image = 'realistic_image' - vector_illustration = 'vector_illustration' - - -class RecraftImageSubStyle(str, Enum): - field_2d_art_poster = '2d_art_poster' - field_3d = '3d' - field_80s = '80s' - glow = 'glow' - grain = 'grain' - hand_drawn = 'hand_drawn' - infantile_sketch = 'infantile_sketch' - kawaii = 'kawaii' - pixel_art = 'pixel_art' - psychedelic = 'psychedelic' - seamless = 'seamless' - voxel = 'voxel' - watercolor = 'watercolor' - broken_line = 'broken_line' - colored_outline = 'colored_outline' - colored_shapes = 'colored_shapes' - colored_shapes_gradient = 'colored_shapes_gradient' - doodle_fill = 'doodle_fill' - doodle_offset_fill = 'doodle_offset_fill' - offset_fill = 'offset_fill' - outline = 'outline' - outline_gradient = 'outline_gradient' - uneven_fill = 'uneven_fill' - field_70s = '70s' - cartoon = 'cartoon' - doodle_line_art = 'doodle_line_art' - engraving = 'engraving' - flat_2 = 'flat_2' - kawaii_1 = 'kawaii' - line_art = 'line_art' - linocut = 'linocut' - seamless_1 = 'seamless' - b_and_w = 'b_and_w' - enterprise = 'enterprise' - hard_flash = 'hard_flash' - hdr = 'hdr' - motion_blur = 'motion_blur' - natural_light = 'natural_light' - studio_portrait = 'studio_portrait' - line_circuit = 'line_circuit' - field_2d_art_poster_2 = '2d_art_poster_2' - engraving_color = 'engraving_color' - flat_air_art = 'flat_air_art' - hand_drawn_outline = 'hand_drawn_outline' - handmade_3d = 'handmade_3d' - stickers_drawings = 'stickers_drawings' - plastic = 'plastic' - pictogram = 'pictogram' - - -class RecraftTransformModel(str, Enum): - refm1 = 'refm1' - recraft20b = 'recraft20b' - recraftv2 = 'recraftv2' - recraftv3 = 'recraftv3' - flux1_1pro = 'flux1_1pro' - flux1dev = 'flux1dev' - imagen3 = 'imagen3' - hidream_i1_dev = 'hidream_i1_dev' - - -class RecraftImageFormat(str, Enum): webp = 'webp' - png = 'png' + jpeg = 'jpeg' -class RecraftResponseFormat(str, Enum): +class OpenAIImageEditRequest(BaseModel): + background: Optional[str] = Field( + None, description='Background transparency', examples=['opaque'] + ) + model: str = Field( + ..., description='The model to use for image editing', examples=['gpt-image-1'] + ) + moderation: Optional[Moderation] = Field( + None, description='Content moderation setting', examples=['auto'] + ) + n: Optional[int] = Field( + None, description='The number of images to generate', examples=[1] + ) + output_compression: Optional[int] = Field( + None, description='Compression level for JPEG or WebP (0-100)', examples=[100] + ) + output_format: Optional[OutputFormat1] = Field( + None, description='Format of the output image', examples=['png'] + ) + prompt: str = Field( + ..., + 
description='A text description of the desired edit', + examples=['Give the rocketship rainbow coloring'], + ) + quality: Optional[str] = Field( + None, description='The quality of the edited image', examples=['low'] + ) + size: Optional[str] = Field( + None, description='Size of the output image', examples=['1024x1024'] + ) + user: Optional[str] = Field( + None, + description='A unique identifier for end-user monitoring', + examples=['user-1234'], + ) + + +class Background(str, Enum): + transparent = 'transparent' + opaque = 'opaque' + + +class Quality(str, Enum): + low = 'low' + medium = 'medium' + high = 'high' + standard = 'standard' + hd = 'hd' + + +class ResponseFormat(str, Enum): url = 'url' b64_json = 'b64_json' -class RecraftImage(BaseModel): - b64_json: Optional[str] = None - features: Optional[RecraftImageFeatures] = None - image_id: UUID - revised_prompt: Optional[str] = None - url: Optional[str] = None - - -class RecraftUserControls(BaseModel): - artistic_level: Optional[int] = None - background_color: Optional[RecraftImageColor] = None - colors: Optional[List[RecraftImageColor]] = None - no_text: Optional[bool] = None - - -class RecraftTextLayout(RootModel[List[RecraftTextLayoutItem]]): - root: List[RecraftTextLayoutItem] - - -class RecraftProcessImageRequest(BaseModel): - image: StrictBytes - image_format: Optional[RecraftImageFormat] = None - response_format: Optional[RecraftResponseFormat] = None - - -class RecraftProcessImageResponse(BaseModel): - created: int - credits: int - image: RecraftImage - - -class RecraftImageToImageRequest(BaseModel): - block_nsfw: Optional[bool] = None - calculate_features: Optional[bool] = None - controls: Optional[RecraftUserControls] = None - image: StrictBytes - image_format: Optional[RecraftImageFormat] = None - model: Optional[RecraftTransformModel] = None - n: Optional[int] = None - negative_prompt: Optional[str] = None - prompt: str - random_seed: Optional[int] = None - response_format: Optional[RecraftResponseFormat] = None - strength: float - style: Optional[RecraftImageStyle] = None - style_id: Optional[UUID] = None - substyle: Optional[RecraftImageSubStyle] = None - text_layout: Optional[RecraftTextLayout] = None - - -class RecraftGenerateImageResponse(BaseModel): - created: int - credits: int - data: List[RecraftImage] - - -class RecraftTransformImageWithMaskRequest(BaseModel): - block_nsfw: Optional[bool] = None - calculate_features: Optional[bool] = None - image: StrictBytes - image_format: Optional[RecraftImageFormat] = None - mask: StrictBytes - model: Optional[RecraftTransformModel] = None - n: Optional[int] = None - negative_prompt: Optional[str] = None - prompt: str - random_seed: Optional[int] = None - response_format: Optional[RecraftResponseFormat] = None - style: Optional[RecraftImageStyle] = None - style_id: Optional[UUID] = None - substyle: Optional[RecraftImageSubStyle] = None - text_layout: Optional[RecraftTextLayout] = None - - -class KlingErrorResponse(BaseModel): - code: int = Field( - ..., - description='- 1000: Authentication failed\n- 1001: Authorization is empty\n- 1002: Authorization is invalid\n- 1003: Authorization is not yet valid\n- 1004: Authorization has expired\n- 1100: Account exception\n- 1101: Account in arrears (postpaid scenario)\n- 1102: Resource pack depleted or expired (prepaid scenario)\n- 1103: Unauthorized access to requested resource\n- 1200: Invalid request parameters\n- 1201: Invalid parameters\n- 1202: Invalid request method\n- 1203: Requested resource does not exist\n- 1300: Trigger 
platform strategy\n- 1301: Trigger content security policy\n- 1302: API request too frequent\n- 1303: Concurrency/QPS exceeds limit\n- 1304: Trigger IP whitelist policy\n- 5000: Internal server error\n- 5001: Service temporarily unavailable\n- 5002: Server internal timeout\n', - ) - message: str = Field(..., description='Human-readable error message') - request_id: str = Field( - ..., description='Request ID for tracking and troubleshooting' - ) - - -class LumaAspectRatio(str, Enum): - field_1_1 = '1:1' - field_16_9 = '16:9' - field_9_16 = '9:16' - field_4_3 = '4:3' - field_3_4 = '3:4' - field_21_9 = '21:9' - field_9_21 = '9:21' - - -class LumaVideoModel(str, Enum): - ray_2 = 'ray-2' - ray_flash_2 = 'ray-flash-2' - ray_1_6 = 'ray-1-6' - - -class LumaVideoModelOutputResolution1(str, Enum): - field_540p = '540p' - field_720p = '720p' - field_1080p = '1080p' - field_4k = '4k' - - -class LumaVideoModelOutputResolution( - RootModel[Union[LumaVideoModelOutputResolution1, str]] -): - root: Union[LumaVideoModelOutputResolution1, str] - - -class LumaVideoModelOutputDuration1(str, Enum): - field_5s = '5s' - field_9s = '9s' - - -class LumaVideoModelOutputDuration( - RootModel[Union[LumaVideoModelOutputDuration1, str]] -): - root: Union[LumaVideoModelOutputDuration1, str] - - -class LumaImageModel(str, Enum): - photon_1 = 'photon-1' - photon_flash_1 = 'photon-flash-1' - - -class LumaImageRef(BaseModel): - url: Optional[AnyUrl] = Field(None, description='The URL of the image reference') - weight: Optional[float] = Field( - None, description='The weight of the image reference' - ) - - -class LumaImageIdentity(BaseModel): - images: Optional[List[AnyUrl]] = Field( - None, description='The URLs of the image identity' - ) - - -class LumaModifyImageRef(BaseModel): - url: Optional[AnyUrl] = Field(None, description='The URL of the image reference') - weight: Optional[float] = Field( - None, description='The weight of the modify image reference' - ) - - -class Type1(str, Enum): - generation = 'generation' - - -class LumaGenerationReference(BaseModel): - type: Literal['generation'] - id: UUID = Field(..., description='The ID of the generation') - - -class Type2(str, Enum): - image = 'image' - - -class LumaImageReference(BaseModel): - type: Literal['image'] - url: AnyUrl = Field(..., description='The URL of the image') - - -class LumaKeyframe(RootModel[Union[LumaGenerationReference, LumaImageReference]]): - root: Union[LumaGenerationReference, LumaImageReference] = Field( - ..., - description='A keyframe can be either a Generation reference, an Image, or a Video', - discriminator='type', - ) - - -class LumaGenerationType(str, Enum): - video = 'video' - image = 'image' - - -class LumaState(str, Enum): - queued = 'queued' - dreaming = 'dreaming' - completed = 'completed' - failed = 'failed' - - -class LumaAssets(BaseModel): - video: Optional[AnyUrl] = Field(None, description='The URL of the video') - image: Optional[AnyUrl] = Field(None, description='The URL of the image') - progress_video: Optional[AnyUrl] = Field( - None, description='The URL of the progress video' - ) - - -class GenerationType(str, Enum): - video = 'video' - - -class GenerationType1(str, Enum): - image = 'image' - - -class CharacterRef(BaseModel): - identity0: Optional[LumaImageIdentity] = None - - -class LumaImageGenerationRequest(BaseModel): - generation_type: Optional[GenerationType1] = 'image' - model: Optional[LumaImageModel] = 'photon-1' - prompt: Optional[str] = Field(None, description='The prompt of the generation') - aspect_ratio: 
Optional[LumaAspectRatio] = '16:9' - callback_url: Optional[AnyUrl] = Field( - None, description='The callback URL for the generation' - ) - image_ref: Optional[List[LumaImageRef]] = None - style_ref: Optional[List[LumaImageRef]] = None - character_ref: Optional[CharacterRef] = None - modify_image_ref: Optional[LumaModifyImageRef] = None - - -class GenerationType2(str, Enum): - upscale_video = 'upscale_video' - - -class LumaUpscaleVideoGenerationRequest(BaseModel): - generation_type: Optional[GenerationType2] = 'upscale_video' - resolution: Optional[LumaVideoModelOutputResolution] = None - callback_url: Optional[AnyUrl] = Field( - None, description='The callback URL for the upscale' - ) - - -class GenerationType3(str, Enum): - add_audio = 'add_audio' - - -class LumaAudioGenerationRequest(BaseModel): - generation_type: Optional[GenerationType3] = 'add_audio' - prompt: Optional[str] = Field(None, description='The prompt of the audio') - negative_prompt: Optional[str] = Field( - None, description='The negative prompt of the audio' - ) - callback_url: Optional[AnyUrl] = Field( - None, description='The callback URL for the audio' - ) - - -class LumaError(BaseModel): - detail: Optional[str] = Field(None, description='The error message') - - -class AspectRatio(str, Enum): - field_16_9 = '16:9' - field_4_3 = '4:3' - field_1_1 = '1:1' - field_3_4 = '3:4' - field_9_16 = '9:16' - - -class Duration(int, Enum): - integer_5 = 5 - integer_8 = 8 - - -class Model1(str, Enum): - v3_5 = 'v3.5' - - -class MotionMode(str, Enum): - normal = 'normal' - fast = 'fast' - - -class Quality(str, Enum): - field_360p = '360p' - field_540p = '540p' - field_720p = '720p' - field_1080p = '1080p' - - class Style(str, Enum): - anime = 'anime' - field_3d_animation = '3d_animation' - clay = 'clay' - comic = 'comic' - cyberpunk = 'cyberpunk' + vivid = 'vivid' + natural = 'natural' -class PixverseTextVideoRequest(BaseModel): - aspect_ratio: AspectRatio - duration: Duration - model: Model1 - motion_mode: Optional[MotionMode] = None - negative_prompt: Optional[str] = None - prompt: str - quality: Quality - seed: Optional[int] = None - style: Optional[Style] = None - template_id: Optional[int] = None - water_mark: Optional[bool] = None - - -class Resp(BaseModel): - video_id: Optional[int] = None - - -class PixverseVideoResponse(BaseModel): - ErrCode: Optional[int] = None - ErrMsg: Optional[str] = None - Resp_1: Optional[Resp] = Field(None, alias='Resp') - - -class Resp1(BaseModel): - img_id: Optional[int] = None - - -class PixverseImageUploadResponse(BaseModel): - ErrCode: Optional[int] = None - ErrMsg: Optional[str] = None - Resp: Optional[Resp1] = None - - -class PixverseImageVideoRequest(BaseModel): - img_id: int - model: Model1 - prompt: str - duration: Duration - quality: Quality - motion_mode: Optional[MotionMode] = None - seed: Optional[int] = None - style: Optional[Style] = None - template_id: Optional[int] = None - water_mark: Optional[bool] = None - - -class PixverseTransitionVideoRequest(BaseModel): - first_frame_img: int - last_frame_img: int - model: Model1 - duration: Duration - quality: Quality - motion_mode: MotionMode - seed: int - prompt: str - style: Optional[Style] = None - template_id: Optional[int] = None - water_mark: Optional[bool] = None - - -class Status2(int, Enum): - integer_1 = 1 - integer_5 = 5 - integer_6 = 6 - integer_7 = 7 - integer_8 = 8 - - -class Resp2(BaseModel): - create_time: Optional[str] = None - id: Optional[int] = None - modify_time: Optional[str] = None - negative_prompt: Optional[str] = 
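The keyframe union being removed above is a Pydantic discriminated union keyed on `type`. A short sketch of how it resolves (class and field names are taken from the removed models):

kf = LumaKeyframe.model_validate(
    {'type': 'image', 'url': 'https://example.com/frame.png'}
)
# 'type' is the discriminator, so validation goes straight to
# LumaImageReference instead of trying each branch in turn.
assert kf.root.type == 'image'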
None - outputHeight: Optional[int] = None - outputWidth: Optional[int] = None - prompt: Optional[str] = None - resolution_ratio: Optional[int] = None - seed: Optional[int] = None - size: Optional[int] = None - status: Optional[Status2] = Field( +class OpenAIImageGenerationRequest(BaseModel): + background: Optional[Background] = Field( + None, description='Background transparency', examples=['opaque'] + ) + model: Optional[str] = Field( + None, description='The model to use for image generation', examples=['dall-e-3'] + ) + moderation: Optional[Moderation] = Field( + None, description='Content moderation setting', examples=['auto'] + ) + n: Optional[int] = Field( None, - description='Video generation status codes:\n* 1 - Generation successful\n* 5 - Generating\n* 6 - Deleted\n* 7 - Contents moderation failed\n* 8 - Generation failed\n', + description='The number of images to generate (1-10). Only 1 supported for dall-e-3.', + examples=[1], ) - style: Optional[str] = None - url: Optional[str] = None - - -class PixverseVideoResultResponse(BaseModel): - ErrCode: Optional[int] = None - ErrMsg: Optional[str] = None - Resp: Optional[Resp2] = None - - -class Image(BaseModel): - bytesBase64Encoded: str - gcsUri: Optional[str] = None - mimeType: Optional[str] = None - - -class Image1(BaseModel): - bytesBase64Encoded: Optional[str] = None - gcsUri: str - mimeType: Optional[str] = None - - -class Instance(BaseModel): - prompt: str = Field(..., description='Text description of the video') - image: Optional[Union[Image, Image1]] = Field( - None, description='Optional image to guide video generation' + output_compression: Optional[int] = Field( + None, description='Compression level for JPEG or WebP (0-100)', examples=[100] ) - - -class PersonGeneration(str, Enum): - ALLOW = 'ALLOW' - BLOCK = 'BLOCK' - - -class Parameters(BaseModel): - aspectRatio: Optional[str] = Field(None, examples=['16:9']) - negativePrompt: Optional[str] = None - personGeneration: Optional[PersonGeneration] = None - sampleCount: Optional[int] = None - seed: Optional[int] = None - storageUri: Optional[str] = Field( - None, description='Optional Cloud Storage URI to upload the video' + output_format: Optional[OutputFormat1] = Field( + None, description='Format of the output image', examples=['png'] ) - durationSeconds: Optional[int] = None - enhancePrompt: Optional[bool] = None - - -class Veo2GenVidRequest(BaseModel): - instances: Optional[List[Instance]] = None - parameters: Optional[Parameters] = None - - -class Veo2GenVidResponse(BaseModel): - name: str = Field( + prompt: str = Field( ..., - description='Operation resource name', - examples=[ - 'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/a1b07c8e-7b5a-4aba-bb34-3e1ccb8afcc8' - ], + description='A text description of the desired image', + examples=['Draw a rocket in front of a blackhole in deep space'], ) - - -class Veo2GenVidPollRequest(BaseModel): - operationName: str = Field( - ..., - description='Full operation name (from predict response)', - examples=[ - 'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID' - ], + quality: Optional[Quality] = Field( + None, description='The quality of the generated image', examples=['high'] ) - - -class Video(BaseModel): - gcsUri: Optional[str] = Field(None, description='Cloud Storage URI of the video') - bytesBase64Encoded: Optional[str] = Field( - None, description='Base64-encoded video content' + response_format: Optional[ResponseFormat] = Field( + 
None, description='Response format of image data', examples=['b64_json'] ) - mimeType: Optional[str] = Field(None, description='Video MIME type') - - -class Response(BaseModel): - field_type: Optional[str] = Field( + size: Optional[str] = Field( None, - alias='@type', - examples=[ - 'type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse' - ], + description='Size of the image (e.g., 1024x1024, 1536x1024, auto)', + examples=['1024x1536'], ) - raiMediaFilteredCount: Optional[int] = Field( - None, description='Count of media filtered by responsible AI policies' + style: Optional[Style] = Field( + None, description='Style of the image (only for dall-e-3)', examples=['vivid'] ) - raiMediaFilteredReasons: Optional[List[str]] = Field( - None, description='Reasons why media was filtered by responsible AI policies' - ) - videos: Optional[List[Video]] = None - - -class Error1(BaseModel): - code: Optional[int] = Field(None, description='Error code') - message: Optional[str] = Field(None, description='Error message') - - -class Veo2GenVidPollResponse(BaseModel): - name: Optional[str] = None - done: Optional[bool] = None - response: Optional[Response] = Field( - None, description='The actual prediction response if done is true' - ) - error: Optional[Error1] = Field( - None, description='Error details if operation failed' + user: Optional[str] = Field( + None, + description='A unique identifier for end-user monitoring', + examples=['user-1234'], ) -class RunwayImageToVideoResponse(BaseModel): - id: Optional[str] = Field(None, description='Task ID') - - -class RunwayTaskStatusEnum(str, Enum): - SUCCEEDED = 'SUCCEEDED' - RUNNING = 'RUNNING' - FAILED = 'FAILED' - PENDING = 'PENDING' - CANCELLED = 'CANCELLED' - THROTTLED = 'THROTTLED' - - -class RunwayModelEnum(str, Enum): - gen4_turbo = 'gen4_turbo' - gen3a_turbo = 'gen3a_turbo' - - -class Position(str, Enum): - first = 'first' - last = 'last' - - -class RunwayPromptImageDetailedObject(BaseModel): - uri: str = Field( - ..., description='A HTTPS URL or data URI containing an encoded image.' - ) - position: Position = Field( - ..., - description="The position of the image in the output video. 'last' is currently supported for gen3a_turbo only.", - ) - - -class RunwayDurationEnum(int, Enum): - integer_5 = 5 - integer_10 = 10 - - -class RunwayAspectRatioEnum(str, Enum): - field_1280_720 = '1280:720' - field_720_1280 = '720:1280' - field_1104_832 = '1104:832' - field_832_1104 = '832:1104' - field_960_960 = '960:960' - field_1584_672 = '1584:672' - field_1280_768 = '1280:768' - field_768_1280 = '768:1280' - - -class RunwayPromptImageObject( - RootModel[Union[str, List[RunwayPromptImageDetailedObject]]] -): - root: Union[str, List[RunwayPromptImageDetailedObject]] = Field( - ..., - description='Image(s) to use for the video generation. 
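In the reordered generation-request model above, `prompt` is the only required field; everything else defaults to None and can be omitted from the payload. A hedged construction sketch using the enums defined in this hunk:

req = OpenAIImageGenerationRequest(
    prompt='Draw a rocket in front of a blackhole in deep space',
    model='dall-e-3',
    n=1,                                  # per the field docs, dall-e-3 only supports 1
    quality=Quality.hd,
    style=Style.vivid,                    # style applies to dall-e-3 only
    response_format=ResponseFormat.b64_json,
)
payload = req.model_dump(exclude_none=True)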
Can be a single URI or an array of image objects with positions.', - ) - - -class Datum3(BaseModel): +class Datum2(BaseModel): b64_json: Optional[str] = Field(None, description='Base64 encoded image data') - url: Optional[str] = Field(None, description='URL of the image') revised_prompt: Optional[str] = Field(None, description='Revised prompt') + url: Optional[str] = Field(None, description='URL of the image') class InputTokensDetails(BaseModel): - text_tokens: Optional[int] = None image_tokens: Optional[int] = None + text_tokens: Optional[int] = None class Usage(BaseModel): @@ -1987,143 +1492,204 @@ class Usage(BaseModel): class OpenAIImageGenerationResponse(BaseModel): - data: Optional[List[Datum3]] = None + data: Optional[List[Datum2]] = None usage: Optional[Usage] = None -class Quality3(str, Enum): - low = 'low' - medium = 'medium' - high = 'high' - standard = 'standard' - hd = 'hd' +class OpenAIModels(str, Enum): + gpt_4 = 'gpt-4' + gpt_4_0314 = 'gpt-4-0314' + gpt_4_0613 = 'gpt-4-0613' + gpt_4_32k = 'gpt-4-32k' + gpt_4_32k_0314 = 'gpt-4-32k-0314' + gpt_4_32k_0613 = 'gpt-4-32k-0613' + gpt_4_0125_preview = 'gpt-4-0125-preview' + gpt_4_turbo = 'gpt-4-turbo' + gpt_4_turbo_2024_04_09 = 'gpt-4-turbo-2024-04-09' + gpt_4_turbo_preview = 'gpt-4-turbo-preview' + gpt_4_1106_preview = 'gpt-4-1106-preview' + gpt_4_vision_preview = 'gpt-4-vision-preview' + gpt_3_5_turbo = 'gpt-3.5-turbo' + gpt_3_5_turbo_16k = 'gpt-3.5-turbo-16k' + gpt_3_5_turbo_0301 = 'gpt-3.5-turbo-0301' + gpt_3_5_turbo_0613 = 'gpt-3.5-turbo-0613' + gpt_3_5_turbo_1106 = 'gpt-3.5-turbo-1106' + gpt_3_5_turbo_0125 = 'gpt-3.5-turbo-0125' + gpt_3_5_turbo_16k_0613 = 'gpt-3.5-turbo-16k-0613' + gpt_4_1 = 'gpt-4.1' + gpt_4_1_mini = 'gpt-4.1-mini' + gpt_4_1_nano = 'gpt-4.1-nano' + gpt_4_1_2025_04_14 = 'gpt-4.1-2025-04-14' + gpt_4_1_mini_2025_04_14 = 'gpt-4.1-mini-2025-04-14' + gpt_4_1_nano_2025_04_14 = 'gpt-4.1-nano-2025-04-14' + o1 = 'o1' + o1_mini = 'o1-mini' + o1_preview = 'o1-preview' + o1_pro = 'o1-pro' + o1_2024_12_17 = 'o1-2024-12-17' + o1_preview_2024_09_12 = 'o1-preview-2024-09-12' + o1_mini_2024_09_12 = 'o1-mini-2024-09-12' + o1_pro_2025_03_19 = 'o1-pro-2025-03-19' + o3 = 'o3' + o3_mini = 'o3-mini' + o3_2025_04_16 = 'o3-2025-04-16' + o3_mini_2025_01_31 = 'o3-mini-2025-01-31' + o4_mini = 'o4-mini' + o4_mini_2025_04_16 = 'o4-mini-2025-04-16' + gpt_4o = 'gpt-4o' + gpt_4o_mini = 'gpt-4o-mini' + gpt_4o_2024_11_20 = 'gpt-4o-2024-11-20' + gpt_4o_2024_08_06 = 'gpt-4o-2024-08-06' + gpt_4o_2024_05_13 = 'gpt-4o-2024-05-13' + gpt_4o_mini_2024_07_18 = 'gpt-4o-mini-2024-07-18' + gpt_4o_audio_preview = 'gpt-4o-audio-preview' + gpt_4o_audio_preview_2024_10_01 = 'gpt-4o-audio-preview-2024-10-01' + gpt_4o_audio_preview_2024_12_17 = 'gpt-4o-audio-preview-2024-12-17' + gpt_4o_mini_audio_preview = 'gpt-4o-mini-audio-preview' + gpt_4o_mini_audio_preview_2024_12_17 = 'gpt-4o-mini-audio-preview-2024-12-17' + gpt_4o_search_preview = 'gpt-4o-search-preview' + gpt_4o_mini_search_preview = 'gpt-4o-mini-search-preview' + gpt_4o_search_preview_2025_03_11 = 'gpt-4o-search-preview-2025-03-11' + gpt_4o_mini_search_preview_2025_03_11 = 'gpt-4o-mini-search-preview-2025-03-11' + computer_use_preview = 'computer-use-preview' + computer_use_preview_2025_03_11 = 'computer-use-preview-2025-03-11' + chatgpt_4o_latest = 'chatgpt-4o-latest' -class OutputFormat1(str, Enum): - png = 'png' - webp = 'webp' - jpeg = 'jpeg' +class Reason(str, Enum): + max_output_tokens = 'max_output_tokens' + content_filter = 'content_filter' -class Moderation(str, Enum): - low = 'low' - auto = 
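Because `OpenAIModels` subclasses `str`, members compare equal to their wire values and can be looked up by value, for example:

chat_model = OpenAIModels('gpt-4o')   # lookup by wire value
assert chat_model is OpenAIModels.gpt_4o
assert chat_model == 'gpt-4o'         # str subclass, so plain comparison works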
'auto' - - -class Background(str, Enum): - transparent = 'transparent' - opaque = 'opaque' - - -class ResponseFormat(str, Enum): - url = 'url' - b64_json = 'b64_json' - - -class Style3(str, Enum): - vivid = 'vivid' - natural = 'natural' - - -class OpenAIImageGenerationRequest(BaseModel): - model: Optional[str] = Field( - None, description='The model to use for image generation', examples=['dall-e-3'] +class IncompleteDetails(BaseModel): + reason: Optional[Reason] = Field( + None, description='The reason why the response is incomplete.' ) - prompt: str = Field( + + +class Object(str, Enum): + response = 'response' + + +class Status6(str, Enum): + completed = 'completed' + failed = 'failed' + in_progress = 'in_progress' + incomplete = 'incomplete' + + +class Type13(str, Enum): + output_audio = 'output_audio' + + +class OutputAudioContent(BaseModel): + data: str = Field(..., description='Base64-encoded audio data') + transcript: str = Field(..., description='Transcript of the audio') + type: Type13 = Field(..., description='The type of output content') + + +class Role4(str, Enum): + assistant = 'assistant' + + +class Type14(str, Enum): + message = 'message' + + +class Type15(str, Enum): + output_text = 'output_text' + + +class OutputTextContent(BaseModel): + text: str = Field(..., description='The text content') + type: Type15 = Field(..., description='The type of output content') + + +class AspectRatio1(RootModel[float]): + root: float = Field( ..., - description='A text description of the desired image', - examples=['Draw a rocket in front of a blackhole in deep space'], - ) - n: Optional[int] = Field( - None, - description='The number of images to generate (1-10). Only 1 supported for dall-e-3.', - examples=[1], - ) - quality: Optional[Quality3] = Field( - None, description='The quality of the generated image', examples=['high'] - ) - size: Optional[str] = Field( - None, - description='Size of the image (e.g., 1024x1024, 1536x1024, auto)', - examples=['1024x1536'], - ) - output_format: Optional[OutputFormat1] = Field( - None, description='Format of the output image', examples=['png'] - ) - output_compression: Optional[int] = Field( - None, description='Compression level for JPEG or WebP (0-100)', examples=[100] - ) - moderation: Optional[Moderation] = Field( - None, description='Content moderation setting', examples=['auto'] - ) - background: Optional[Background] = Field( - None, description='Background transparency', examples=['opaque'] - ) - response_format: Optional[ResponseFormat] = Field( - None, description='Response format of image data', examples=['b64_json'] - ) - style: Optional[Style3] = Field( - None, description='Style of the image (only for dall-e-3)', examples=['vivid'] - ) - user: Optional[str] = Field( - None, - description='A unique identifier for end-user monitoring', - examples=['user-1234'], + description='Aspect ratio (width / height)', + ge=0.4, + le=2.5, + title='Aspectratio', ) -class OpenAIImageEditRequest(BaseModel): - model: str = Field( - ..., description='The model to use for image editing', examples=['gpt-image-1'] - ) - prompt: str = Field( - ..., - description='A text description of the desired edit', - examples=['Give the rocketship rainbow coloring'], - ) - n: Optional[int] = Field( - None, description='The number of images to generate', examples=[1] - ) - quality: Optional[str] = Field( - None, description='The quality of the edited image', examples=['low'] - ) - size: Optional[str] = Field( - None, description='Size of the output image', 
examples=['1024x1024'] - ) - output_format: Optional[OutputFormat1] = Field( - None, description='Format of the output image', examples=['png'] - ) - output_compression: Optional[int] = Field( - None, description='Compression level for JPEG or WebP (0-100)', examples=[100] - ) - moderation: Optional[Moderation] = Field( - None, description='Content moderation setting', examples=['auto'] - ) - background: Optional[str] = Field( - None, description='Background transparency', examples=['opaque'] - ) - user: Optional[str] = Field( - None, - description='A unique identifier for end-user monitoring', - examples=['user-1234'], - ) +class IngredientsMode(str, Enum): + creative = 'creative' + precise = 'precise' -class CustomerStorageResourceResponse(BaseModel): - download_url: Optional[str] = Field( +class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel): + aspectRatio: Optional[AspectRatio1] = Field( + None, description='Aspect ratio (width / height)', title='Aspectratio' + ) + duration: Optional[int] = Field(5, title='Duration') + images: Optional[List[StrictBytes]] = Field(None, title='Images') + ingredientsMode: IngredientsMode = Field(..., title='Ingredientsmode') + negativePrompt: Optional[str] = Field(None, title='Negativeprompt') + promptText: Optional[str] = Field(None, title='Prompttext') + resolution: Optional[str] = Field('1080p', title='Resolution') + seed: Optional[int] = Field(None, title='Seed') + + +class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel): + image: Optional[StrictBytes] = Field(None, title='Image') + negativePrompt: Optional[str] = Field(None, title='Negativeprompt') + promptText: Optional[str] = Field(None, title='Prompttext') + seed: Optional[int] = Field(None, title='Seed') + video: Optional[StrictBytes] = Field(None, title='Video') + + +class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel): + image: Optional[StrictBytes] = Field(None, title='Image') + modifyRegionMask: Optional[StrictBytes] = Field( None, - description='The signed URL to use for downloading the file from the specified path', + description='A mask image that specifies the region to modify, where the mask is white and the background is black', + title='Modifyregionmask', ) - upload_url: Optional[str] = Field( + modifyRegionRoi: Optional[str] = Field( None, - description='The signed URL to use for uploading the file to the specified path', - ) - expires_at: Optional[datetime] = Field( - None, description='When the signed URL will expire' - ) - existing_file: Optional[bool] = Field( - None, description='Whether an existing file with the same hash was found' + description='Plaintext description of the object / region to modify', + title='Modifyregionroi', ) + negativePrompt: Optional[str] = Field(None, title='Negativeprompt') + promptText: Optional[str] = Field(None, title='Prompttext') + seed: Optional[int] = Field(None, title='Seed') + video: Optional[StrictBytes] = Field(None, title='Video') + + +class PikaDurationEnum(int, Enum): + integer_5 = 5 + integer_10 = 10 + + +class PikaGenerateResponse(BaseModel): + video_id: str = Field(..., title='Video Id') + + +class PikaResolutionEnum(str, Enum): + field_1080p = '1080p' + field_720p = '720p' + + +class PikaStatusEnum(str, Enum): + queued = 'queued' + started = 'started' + finished = 'finished' + + +class PikaValidationError(BaseModel): + loc: List[Union[str, int]] = Field(..., title='Location') + msg: str = Field(..., title='Message') + type: str = Field(..., title='Error Type') + + +class 
PikaVideoResponse(BaseModel): + id: str = Field(..., title='Id') + progress: Optional[int] = Field(None, title='Progress') + status: PikaStatusEnum + url: Optional[str] = Field(None, title='Url') class Pikaffect(str, Enum): @@ -2145,92 +1711,135 @@ class Pikaffect(str, Enum): Tear = 'Tear' -class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel): - image: Optional[StrictBytes] = Field(None, title='Image') - pikaffect: Optional[Pikaffect] = Field(None, title='Pikaffect') - promptText: Optional[str] = Field(None, title='Prompttext') - negativePrompt: Optional[str] = Field(None, title='Negativeprompt') - seed: Optional[int] = Field(None, title='Seed') +class Resp(BaseModel): + img_id: Optional[int] = None -class PikaGenerateResponse(BaseModel): - video_id: str = Field(..., title='Video Id') +class PixverseImageUploadResponse(BaseModel): + ErrCode: Optional[int] = None + ErrMsg: Optional[str] = None + Resp_1: Optional[Resp] = Field(None, alias='Resp') -class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel): - video: Optional[StrictBytes] = Field(None, title='Video') - image: Optional[StrictBytes] = Field(None, title='Image') - promptText: Optional[str] = Field(None, title='Prompttext') - negativePrompt: Optional[str] = Field(None, title='Negativeprompt') - seed: Optional[int] = Field(None, title='Seed') - - -class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel): - video: Optional[StrictBytes] = Field(None, title='Video') - image: Optional[StrictBytes] = Field(None, title='Image') - promptText: Optional[str] = Field(None, title='Prompttext') - modifyRegionMask: Optional[StrictBytes] = Field( - None, - description='A mask image that specifies the region to modify, where the mask is white and the background is black', - title='Modifyregionmask', - ) - modifyRegionRoi: Optional[str] = Field( - None, - description='Plaintext description of the object / region to modify', - title='Modifyregionroi', - ) - negativePrompt: Optional[str] = Field(None, title='Negativeprompt') - seed: Optional[int] = Field(None, title='Seed') - - -class IngredientsMode(str, Enum): - creative = 'creative' - precise = 'precise' - - -class AspectRatio1(RootModel[float]): - root: float = Field( - ..., - description='Aspect ratio (width / height)', - ge=0.4, - le=2.5, - title='Aspectratio', - ) - - -class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel): - images: Optional[List[StrictBytes]] = Field(None, title='Images') - ingredientsMode: IngredientsMode = Field(..., title='Ingredientsmode') - promptText: Optional[str] = Field(None, title='Prompttext') - negativePrompt: Optional[str] = Field(None, title='Negativeprompt') - seed: Optional[int] = Field(None, title='Seed') - resolution: Optional[str] = Field('1080p', title='Resolution') - duration: Optional[int] = Field(5, title='Duration') - aspectRatio: Optional[AspectRatio1] = Field( - None, description='Aspect ratio (width / height)', title='Aspectratio' - ) - - -class PikaStatusEnum(str, Enum): - queued = 'queued' - started = 'started' - finished = 'finished' - - -class PikaValidationError(BaseModel): - loc: List[Union[str, int]] = Field(..., title='Location') - msg: str = Field(..., title='Message') - type: str = Field(..., title='Error Type') - - -class PikaResolutionEnum(str, Enum): - field_1080p = '1080p' - field_720p = '720p' - - -class PikaDurationEnum(int, Enum): +class Duration(int, Enum): integer_5 = 5 - integer_10 = 10 + integer_8 = 8 + + +class Model1(str, Enum): + v3_5 = 'v3.5' + + +class MotionMode(str, Enum): + normal 
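`AspectRatio1` wraps a bare float in a RootModel so the 0.4-2.5 bounds travel with the type. A sketch of the resulting validation behaviour (field names come from the Pikascenes body above; the scenario is illustrative):

from pydantic import ValidationError

try:
    PikaBodyGenerate22C2vGenerate22PikascenesPost(
        ingredientsMode=IngredientsMode.creative,
        aspectRatio=3.0,              # outside the declared 0.4-2.5 range
    )
except ValidationError as exc:
    print(exc.errors()[0]['loc'])     # points at ('aspectRatio',)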
= 'normal' + fast = 'fast' + + +class Quality1(str, Enum): + field_360p = '360p' + field_540p = '540p' + field_720p = '720p' + field_1080p = '1080p' + + +class Style1(str, Enum): + anime = 'anime' + field_3d_animation = '3d_animation' + clay = 'clay' + comic = 'comic' + cyberpunk = 'cyberpunk' + + +class PixverseImageVideoRequest(BaseModel): + duration: Duration + img_id: int + model: Model1 + motion_mode: Optional[MotionMode] = None + prompt: str + quality: Quality1 + seed: Optional[int] = None + style: Optional[Style1] = None + template_id: Optional[int] = None + water_mark: Optional[bool] = None + + +class AspectRatio2(str, Enum): + field_16_9 = '16:9' + field_4_3 = '4:3' + field_1_1 = '1:1' + field_3_4 = '3:4' + field_9_16 = '9:16' + + +class PixverseTextVideoRequest(BaseModel): + aspect_ratio: AspectRatio2 + duration: Duration + model: Model1 + motion_mode: Optional[MotionMode] = None + negative_prompt: Optional[str] = None + prompt: str + quality: Quality1 + seed: Optional[int] = None + style: Optional[Style1] = None + template_id: Optional[int] = None + water_mark: Optional[bool] = None + + +class PixverseTransitionVideoRequest(BaseModel): + duration: Duration + first_frame_img: int + last_frame_img: int + model: Model1 + motion_mode: MotionMode + prompt: str + quality: Quality1 + seed: int + style: Optional[Style1] = None + template_id: Optional[int] = None + water_mark: Optional[bool] = None + + +class Resp1(BaseModel): + video_id: Optional[int] = None + + +class PixverseVideoResponse(BaseModel): + ErrCode: Optional[int] = None + ErrMsg: Optional[str] = None + Resp: Optional[Resp1] = None + + +class Status7(int, Enum): + integer_1 = 1 + integer_5 = 5 + integer_6 = 6 + integer_7 = 7 + integer_8 = 8 + + +class Resp2(BaseModel): + create_time: Optional[str] = None + id: Optional[int] = None + modify_time: Optional[str] = None + negative_prompt: Optional[str] = None + outputHeight: Optional[int] = None + outputWidth: Optional[int] = None + prompt: Optional[str] = None + resolution_ratio: Optional[int] = None + seed: Optional[int] = None + size: Optional[int] = None + status: Optional[Status7] = Field( + None, + description='Video generation status codes:\n* 1 - Generation successful\n* 5 - Generating\n* 6 - Deleted\n* 7 - Contents moderation failed\n* 8 - Generation failed\n', + ) + style: Optional[str] = None + url: Optional[str] = None + + +class PixverseVideoResultResponse(BaseModel): + ErrCode: Optional[int] = None + ErrMsg: Optional[str] = None + Resp: Optional[Resp2] = None class RgbItem(RootModel[int]): @@ -2241,213 +1850,364 @@ class RGBColor(BaseModel): rgb: List[RgbItem] = Field(..., max_length=3, min_length=3) -class StabilityStabilityClientID(RootModel[str]): - root: str = Field( +class GenerateSummary(str, Enum): + auto = 'auto' + concise = 'concise' + detailed = 'detailed' + + +class Summary(str, Enum): + auto = 'auto' + concise = 'concise' + detailed = 'detailed' + + +class ReasoningEffort(str, Enum): + low = 'low' + medium = 'medium' + high = 'high' + + +class Status8(str, Enum): + in_progress = 'in_progress' + completed = 'completed' + incomplete = 'incomplete' + + +class Type16(str, Enum): + summary_text = 'summary_text' + + +class SummaryItem(BaseModel): + text: str = Field( ..., - description='The name of your application, used to help us communicate app-specific debugging or moderation issues to you.', - examples=['my-awesome-app'], - max_length=256, + description='A short summary of the reasoning used by the model when generating\nthe response.\n', + ) + type: 
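Per the `Status7` docs above, 5 is the only in-flight code; 1 is success and 6/7/8 are terminal failures. A small helper a polling client might use (the function itself is an illustrative sketch, not part of the diff):

def poll_outcome(resp: PixverseVideoResultResponse) -> str:
    status = resp.Resp.status if resp.Resp else None
    if status is None or status is Status7.integer_5:
        return 'pending'              # still generating (or no payload yet)
    if status is Status7.integer_1:
        return 'succeeded'
    return 'failed'                   # 6 deleted, 7 moderation failed, 8 failed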
Type16 = Field( + ..., description='The type of the object. Always `summary_text`.\n' ) -class StabilityStabilityClientUserID(RootModel[str]): - root: str = Field( - ..., - description='A unique identifier for your end user. Used to help us communicate user-specific debugging or moderation issues to you. Feel free to obfuscate this value to protect user privacy.', - examples=['DiscordUser#9999'], - max_length=256, - ) +class Type17(str, Enum): + reasoning = 'reasoning' -class StabilityStabilityClientVersion(RootModel[str]): - root: str = Field( - ..., - description='The version of your application, used to help us communicate version-specific debugging or moderation issues to you.', - examples=['1.2.1'], - max_length=256, - ) - - -class Name(str, Enum): - content_moderation = 'content_moderation' - - -class StabilityContentModerationResponse(BaseModel): +class ReasoningItem(BaseModel): id: str = Field( - ..., - description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new) you file, as it will greatly assist us in diagnosing the root cause of the problem.', - examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], - min_length=1, + ..., description='The unique identifier of the reasoning content.\n' ) - name: Name = Field( - ..., - description='Our content moderation system has flagged some part of your request and subsequently denied it. You were not charged for this request. While this may at times be frustrating, it is necessary to maintain the integrity of our platform and ensure a safe experience for all users. If you would like to provide feedback, please use the [Support Form](https://kb.stability.ai/knowledge-base/kb-tickets/new).', + status: Optional[Status8] = Field( + None, + description='The status of the item. One of `in_progress`, `completed`, or\n`incomplete`. Populated when items are returned via API.\n', ) - errors: List[str] = Field( - ..., - description='One or more error messages indicating what went wrong.', - examples=[['some-field: is required']], - min_length=1, + summary: List[SummaryItem] = Field(..., description='Reasoning text contents.\n') + type: Type17 = Field( + ..., description='The type of the object. Always `reasoning`.\n' ) +class Controls(BaseModel): + artistic_level: Optional[int] = Field( + None, + description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. 
Dynamic and eccentric levels introduce movement and creativity.', + ge=0, + le=5, + ) + background_color: Optional[RGBColor] = None + colors: Optional[List[RGBColor]] = Field( + None, description='An array of preferable colors' + ) + no_text: Optional[bool] = Field(None, description='Do not embed text layouts') + + +class RecraftImageGenerationRequest(BaseModel): + controls: Optional[Controls] = Field( + None, description='The controls for the generated image' + ) + model: str = Field( + ..., description='The model to use for generation (e.g., "recraftv3")' + ) + n: int = Field(..., description='The number of images to generate', ge=1, le=4) + prompt: str = Field( + ..., description='The text prompt describing the image to generate' + ) + size: str = Field( + ..., description='The size of the generated image (e.g., "1024x1024")' + ) + style: Optional[str] = Field( + None, + description='The style to apply to the generated image (e.g., "digital_illustration")', + ) + style_id: Optional[str] = Field( + None, + description='The style ID to apply to the generated image (e.g., "123e4567-e89b-12d3-a456-426614174000"). If style_id is provided, style should not be provided.', + ) + + +class Datum3(BaseModel): + image_id: Optional[str] = Field( + None, description='Unique identifier for the generated image' + ) + url: Optional[str] = Field(None, description='URL to access the generated image') + + +class RecraftImageGenerationResponse(BaseModel): + created: int = Field( + ..., description='Unix timestamp when the generation was created' + ) + credits: int = Field(..., description='Number of credits used for the generation') + data: List[Datum3] = Field(..., description='Array of generated image information') + + class RenderingSpeed(str, Enum): BALANCED = 'BALANCED' TURBO = 'TURBO' QUALITY = 'QUALITY' -class StabilityCreativity(RootModel[float]): - root: float = Field( +class ResponseErrorCode(str, Enum): + server_error = 'server_error' + rate_limit_exceeded = 'rate_limit_exceeded' + invalid_prompt = 'invalid_prompt' + vector_store_timeout = 'vector_store_timeout' + invalid_image = 'invalid_image' + invalid_image_format = 'invalid_image_format' + invalid_base64_image = 'invalid_base64_image' + invalid_image_url = 'invalid_image_url' + image_too_large = 'image_too_large' + image_too_small = 'image_too_small' + image_parse_error = 'image_parse_error' + image_content_policy_violation = 'image_content_policy_violation' + invalid_image_mode = 'invalid_image_mode' + image_file_too_large = 'image_file_too_large' + unsupported_image_media_type = 'unsupported_image_media_type' + empty_image_file = 'empty_image_file' + failed_to_download_image = 'failed_to_download_image' + image_file_not_found = 'image_file_not_found' + + +class Type18(str, Enum): + json_object = 'json_object' + + +class ResponseFormatJsonObject(BaseModel): + type: Type18 = Field( ..., - description='Controls the likelihood of creating additional details not heavily conditioned by the init image.', - ge=0.2, - le=0.5, + description='The type of response format being defined. Always `json_object`.', ) -class StabilityGenerationID(RootModel[str]): - root: str = Field( +class ResponseFormatJsonSchemaSchema(BaseModel): + pass + model_config = ConfigDict( + extra='allow', + ) + + +class Type19(str, Enum): + text = 'text' + + +class ResponseFormatText(BaseModel): + type: Type19 = Field( + ..., description='The type of response format being defined. Always `text`.' 
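The Recraft request nests `Controls` and `RGBColor`, and plain ints and lists coerce into the `RgbItem` root models automatically. A construction sketch (the values are illustrative, not from the diff):

req = RecraftImageGenerationRequest(
    model='recraftv3',
    prompt='an isometric server room',
    n=2,                               # must be 1-4 per the field bounds
    size='1024x1024',
    controls=Controls(
        artistic_level=3,              # 0-5 scale, per the description above
        colors=[RGBColor(rgb=[64, 128, 255])],
    ),
)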
+ ) + + +class Truncation1(str, Enum): + auto = 'auto' + disabled = 'disabled' + + +class InputTokensDetails1(BaseModel): + cached_tokens: int = Field( ..., - description='The `id` of a generation, typically used for async generations, that can be used to check the status of the generation or retrieve the result.', - examples=['a6dc6c6e20acda010fe14d71f180658f2896ed9b4ec25aa99a6ff06c796987c4'], - max_length=64, - min_length=64, + description='The number of tokens that were retrieved from the cache. \n[More on prompt caching](/docs/guides/prompt-caching).\n', ) -class Mode(str, Enum): - text_to_image = 'text-to-image' - image_to_image = 'image-to-image' +class OutputTokensDetails(BaseModel): + reasoning_tokens: int = Field(..., description='The number of reasoning tokens.') -class AspectRatio2(str, Enum): - field_21_9 = '21:9' - field_16_9 = '16:9' - field_3_2 = '3:2' - field_5_4 = '5:4' - field_1_1 = '1:1' - field_4_5 = '4:5' - field_2_3 = '2:3' - field_9_16 = '9:16' - field_9_21 = '9:21' +class ResponseUsage(BaseModel): + input_tokens: int = Field(..., description='The number of input tokens.') + input_tokens_details: InputTokensDetails1 = Field( + ..., description='A detailed breakdown of the input tokens.' + ) + output_tokens: int = Field(..., description='The number of output tokens.') + output_tokens_details: OutputTokensDetails = Field( + ..., description='A detailed breakdown of the output tokens.' + ) + total_tokens: int = Field(..., description='The total number of tokens used.') -class Model4(str, Enum): - sd3_5_large = 'sd3.5-large' - sd3_5_large_turbo = 'sd3.5-large-turbo' - sd3_5_medium = 'sd3.5-medium' +class Rodin3DCheckStatusRequest(BaseModel): + subscription_key: str = Field( + ..., description='subscription from generate endpoint' + ) -class OutputFormat3(str, Enum): - png = 'png' - jpeg = 'jpeg' +class Rodin3DCheckStatusResponse(BaseModel): + pass -class StylePreset(str, Enum): - enhance = 'enhance' - anime = 'anime' - photographic = 'photographic' - digital_art = 'digital-art' - comic_book = 'comic-book' - fantasy_art = 'fantasy-art' - line_art = 'line-art' - analog_film = 'analog-film' - neon_punk = 'neon-punk' - isometric = 'isometric' - low_poly = 'low-poly' - origami = 'origami' - modeling_compound = 'modeling-compound' - cinematic = 'cinematic' - field_3d_model = '3d-model' - pixel_art = 'pixel-art' - tile_texture = 'tile-texture' +class Rodin3DDownloadRequest(BaseModel): + task_uuid: str = Field(..., description='Task UUID') -class StabilityImageGenrationSD3Request(BaseModel): - prompt: str = Field( +class RodinGenerateJobsData(BaseModel): + subscription_key: Optional[str] = Field(None, description='Subscription Key.') + uuids: Optional[List[str]] = Field(None, description='subjobs uuid.') + + +class RodinMaterialType(str, Enum): + PBR = 'PBR' + Shaded = 'Shaded' + + +class RodinMeshModeType(str, Enum): + Quad = 'Quad' + Raw = 'Raw' + + +class RodinQualityType(str, Enum): + extra_low = 'extra-low' + low = 'low' + medium = 'medium' + high = 'high' + + +class RodinResourceItem(BaseModel): + name: Optional[str] = Field(None, description='File name') + url: Optional[str] = Field(None, description='Download url') + + +class RodinTierType(str, Enum): + Regular = 'Regular' + Sketch = 'Sketch' + Detail = 'Detail' + Smooth = 'Smooth' + + +class RunwayAspectRatioEnum(str, Enum): + field_1280_720 = '1280:720' + field_720_1280 = '720:1280' + field_1104_832 = '1104:832' + field_832_1104 = '832:1104' + field_960_960 = '960:960' + field_1584_672 = '1584:672' + field_1280_768 = 
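All fields on `ResponseUsage` and its detail objects are required, so a well-formed usage block always carries the cached/reasoning breakdowns. A sketch with made-up numbers:

usage = ResponseUsage(
    input_tokens=1200,
    input_tokens_details=InputTokensDetails1(cached_tokens=800),
    output_tokens=350,
    output_tokens_details=OutputTokensDetails(reasoning_tokens=120),
    total_tokens=1550,
)
# total_tokens is reported by the API rather than derived, but it
# should equal the sum of the two sides:
assert usage.total_tokens == usage.input_tokens + usage.output_tokens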
'1280:768' + field_768_1280 = '768:1280' + + +class RunwayDurationEnum(int, Enum): + integer_5 = 5 + integer_10 = 10 + + +class RunwayImageToVideoResponse(BaseModel): + id: Optional[str] = Field(None, description='Task ID') + + +class RunwayModelEnum(str, Enum): + gen4_turbo = 'gen4_turbo' + gen3a_turbo = 'gen3a_turbo' + + +class Position(str, Enum): + first = 'first' + last = 'last' + + +class RunwayPromptImageDetailedObject(BaseModel): + position: Position = Field( ..., - description='What you wish to see in the output image. A strong, descriptive prompt that clearly defines\nelements, colors, and subjects will lead to better results.', - max_length=10000, - min_length=1, + description="The position of the image in the output video. 'last' is currently supported for gen3a_turbo only.", ) - mode: Optional[Mode] = Field( - 'text-to-image', - description='Controls whether this is a text-to-image or image-to-image generation, which affects which parameters are required:\n- **text-to-image** requires only the `prompt` parameter\n- **image-to-image** requires the `prompt`, `image`, and `strength` parameters', - title='GenerationMode', + uri: str = Field( + ..., description='A HTTPS URL or data URI containing an encoded image.' ) - image: Optional[StrictBytes] = Field( + + +class RunwayPromptImageObject( + RootModel[Union[str, List[RunwayPromptImageDetailedObject]]] +): + root: Union[str, List[RunwayPromptImageDetailedObject]] = Field( + ..., + description='Image(s) to use for the video generation. Can be a single URI or an array of image objects with positions.', + ) + + +class RunwayTaskStatusEnum(str, Enum): + SUCCEEDED = 'SUCCEEDED' + RUNNING = 'RUNNING' + FAILED = 'FAILED' + PENDING = 'PENDING' + CANCELLED = 'CANCELLED' + THROTTLED = 'THROTTLED' + + +class RunwayTaskStatusResponse(BaseModel): + createdAt: datetime = Field(..., description='Task creation timestamp') + id: str = Field(..., description='Task ID') + output: Optional[List[str]] = Field(None, description='Array of output video URLs') + progress: Optional[float] = Field( None, - description='The image to use as the starting point for the generation.\n\nSupported formats:\n\n\n\n - jpeg\n - png\n - webp\n\nSupported dimensions:\n\n\n\n - Every side must be at least 64 pixels\n\n> **Important:** This parameter is only valid for **image-to-image** requests.', - ) - strength: Optional[float] = Field( - None, - description='Sometimes referred to as _denoising_, this parameter controls how much influence the\n`image` parameter has on the generated image. A value of 0 would yield an image that\nis identical to the input. A value of 1 would be as if you passed in no image at all.\n\n> **Important:** This parameter is only valid for **image-to-image** requests.', + description='Float value between 0 and 1 representing the progress of the task. Only available if status is RUNNING.', ge=0.0, le=1.0, ) - aspect_ratio: Optional[AspectRatio2] = Field( - '1:1', - description='Controls the aspect ratio of the generated image. 
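`RunwayTaskStatusResponse.progress` is only populated while the task is RUNNING, so a poller should key off `status` rather than `progress`. A sketch of such a loop, where `fetch` is a hypothetical callable that GETs the task by id and returns a parsed `RunwayTaskStatusResponse`:

import time

TERMINAL = {
    RunwayTaskStatusEnum.SUCCEEDED,
    RunwayTaskStatusEnum.FAILED,
    RunwayTaskStatusEnum.CANCELLED,
}

def wait_for_task(fetch, task_id: str, interval: float = 2.0) -> RunwayTaskStatusResponse:
    while True:
        task = fetch(task_id)          # hypothetical helper, not in this diff
        if task.status in TERMINAL:
            return task
        time.sleep(interval)           # PENDING / RUNNING / THROTTLED keep polling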
Defaults to 1:1.\n\n> **Important:** This parameter is only valid for **text-to-image** requests.', - ) - model: Optional[Model4] = Field( - 'sd3.5-large', - description='The model to use for generation.\n\n- `sd3.5-large` requires 6.5 credits per generation\n- `sd3.5-large-turbo` requires 4 credits per generation\n- `sd3.5-medium` requires 3.5 credits per generation\n- As of the April 17, 2025, `sd3-large`, `sd3-large-turbo` and `sd3-medium`\n\n\n\n are re-routed to their `sd3.5-[model version]` equivalent, at the same price.', - ) - seed: Optional[float] = Field( - 0, - description="A specific value that is used to guide the 'randomness' of the generation. (Omit this parameter or pass `0` to use a random seed.)", - ge=0.0, - le=4294967294.0, - ) - output_format: Optional[OutputFormat3] = Field( - 'png', description='Dictates the `content-type` of the generated image.' - ) - style_preset: Optional[StylePreset] = Field( - None, description='Guides the image model towards a particular style.' - ) - negative_prompt: Optional[str] = Field( - None, - description='Keywords of what you **do not** wish to see in the output image.\nThis is an advanced feature.', - max_length=10000, - ) - cfg_scale: Optional[float] = Field( - None, - description='How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt). The _Large_ and _Medium_ models use a default of `4`. The _Turbo_ model uses a default of `1`.', - ge=1.0, - le=10.0, + status: RunwayTaskStatusEnum + + +class RunwayTextToImageAspectRatioEnum(str, Enum): + field_1920_1080 = '1920:1080' + field_1080_1920 = '1080:1920' + field_1024_1024 = '1024:1024' + field_1360_768 = '1360:768' + field_1080_1080 = '1080:1080' + field_1168_880 = '1168:880' + field_1440_1080 = '1440:1080' + field_1080_1440 = '1080:1440' + field_1808_768 = '1808:768' + field_2112_912 = '2112:912' + +class Model4(str, Enum): + gen4_image = 'gen4_image' + + +class ReferenceImage(BaseModel): + uri: Optional[str] = Field( + None, description='A HTTPS URL or data URI containing an encoded image' ) -class FinishReason(str, Enum): - SUCCESS = 'SUCCESS' - CONTENT_FILTERED = 'CONTENT_FILTERED' +class RunwayTextToImageRequest(BaseModel): + model: Model4 = Field(..., description='Model to use for generation') + promptText: str = Field( + ..., description='Text prompt for the image generation', max_length=1000 + ) + ratio: RunwayTextToImageAspectRatioEnum + referenceImages: Optional[List[ReferenceImage]] = Field( + None, description='Array of reference images to guide the generation' + ) -class StabilityImageGenrationSD3Response200(BaseModel): - image: str = Field( +class RunwayTextToImageResponse(BaseModel): + id: Optional[str] = Field(None, description='Task ID') + + +class StabilityError(BaseModel): + errors: List[str] = Field( ..., - description='The generated image, encoded to base64.', - examples=['AAAAIGZ0eXBpc29tAAACAGlzb21pc28yYXZjMW1...'], + description='One or more error messages indicating what went wrong.', + examples=[[{'some-field': 'is required'}]], + min_length=1, ) - seed: Optional[float] = Field( - 0, - description='The seed used as random noise for this generation.', - examples=[343940597], - ge=0.0, - le=4294967294.0, - ) - finish_reason: FinishReason = Field( - ..., - description='The reason the generation finished.\n\n- `SUCCESS` = successful generation.\n- `CONTENT_FILTERED` = successful generation, however the output violated our content moderation\npolicy and has been blurred as a result.', - 
examples=['SUCCESS'], - ) - - -class StabilityImageGenrationSD3Response400(BaseModel): id: str = Field( ..., - description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new) you file, as it will greatly assist us in diagnosing the root cause of the problem.\n', examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], min_length=1, ) @@ -2457,704 +2217,501 @@ class StabilityImageGenrationSD3Response400(BaseModel): examples=['bad_request'], min_length=1, ) - errors: List[str] = Field( - ..., - description='One or more error messages indicating what went wrong.', - examples=[['some-field: is required']], - min_length=1, - ) -class StabilityImageGenrationSD3Response413(BaseModel): - id: str = Field( - ..., - description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', - examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], - min_length=1, - ) - name: str = Field( - ..., - description='Short-hand name for an error, useful for discriminating between errors with the same status code.', - examples=['bad_request'], - min_length=1, - ) - errors: List[str] = Field( - ..., - description='One or more error messages indicating what went wrong.', - examples=[['some-field: is required']], - min_length=1, - ) +class Status9(str, Enum): + in_progress = 'in-progress' -class StabilityImageGenrationSD3Response422(BaseModel): - id: str = Field( - ..., - description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', - examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], - min_length=1, - ) - name: str = Field( - ..., - description='Short-hand name for an error, useful for discriminating between errors with the same status code.', - examples=['bad_request'], - min_length=1, - ) - errors: List[str] = Field( - ..., - description='One or more error messages indicating what went wrong.', - examples=[['some-field: is required']], - min_length=1, - ) - - -class StabilityImageGenrationSD3Response429(BaseModel): - id: str = Field( - ..., - description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', - examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], - min_length=1, - ) - name: str = Field( - ..., - description='Short-hand name for an error, useful for discriminating between errors with the same status code.', - examples=['bad_request'], - min_length=1, - ) - errors: List[str] = Field( - ..., - description='One or more error messages indicating what went wrong.', - examples=[['some-field: is required']], - min_length=1, - ) - - -class StabilityImageGenrationSD3Response500(BaseModel): - id: str = Field( - ..., - description='A unique identifier associated with this error. 
Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', - examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], - min_length=1, - ) - name: str = Field( - ..., - description='Short-hand name for an error, useful for discriminating between errors with the same status code.', - examples=['bad_request'], - min_length=1, - ) - errors: List[str] = Field( - ..., - description='One or more error messages indicating what went wrong.', - examples=[['some-field: is required']], - min_length=1, - ) - - -class OutputFormat4(str, Enum): - jpeg = 'jpeg' - png = 'png' - webp = 'webp' - - -class StabilityImageGenrationUpscaleConservativeRequest(BaseModel): - image: StrictBytes = Field( - ..., - description='The image you wish to upscale.\n\nSupported Formats:\n- jpeg\n- png\n- webp\n\nValidation Rules:\n- Every side must be at least 64 pixels\n- Total pixel count must be between 4,096 and 9,437,184 pixels\n- The aspect ratio must be between 1:2.5 and 2.5:1', - examples=['./some/image.png'], - ) - prompt: str = Field( - ..., - description="What you wish to see in the output image. A strong, descriptive prompt that clearly defines\nelements, colors, and subjects will lead to better results.\n\nTo control the weight of a given word use the format `(word:weight)`,\nwhere `word` is the word you'd like to control the weight of and `weight`\nis a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`\nwould convey a sky that was blue and green, but more green than blue.", - max_length=10000, - min_length=1, - ) - negative_prompt: Optional[str] = Field( - None, - description='A blurb of text describing what you **do not** wish to see in the output image.\nThis is an advanced feature.', - max_length=10000, - ) - seed: Optional[float] = Field( - 0, - description="A specific value that is used to guide the 'randomness' of the generation. (Omit this parameter or pass `0` to use a random seed.)", - ge=0.0, - le=4294967294.0, - ) - output_format: Optional[OutputFormat4] = Field( - 'png', description='Dictates the `content-type` of the generated image.' - ) - creativity: Optional[StabilityCreativity] = Field( - default_factory=lambda: StabilityCreativity.model_validate(0.35) - ) - - -class StabilityImageGenrationUpscaleConservativeResponse200(BaseModel): - image: str = Field( - ..., - description='The generated image, encoded to base64.', - examples=['AAAAIGZ0eXBpc29tAAACAGlzb21pc28yYXZjMW1...'], - ) - seed: Optional[float] = Field( - 0, - description='The seed used as random noise for this generation.', - examples=[343940597], - ge=0.0, - le=4294967294.0, - ) - finish_reason: FinishReason = Field( - ..., - description='The reason the generation finished.\n\n- `SUCCESS` = successful generation.\n- `CONTENT_FILTERED` = successful generation, however the output violated our content moderation\npolicy and has been blurred as a result.', - examples=['SUCCESS'], - ) - - -class StabilityImageGenrationUpscaleConservativeResponse400(BaseModel): - id: str = Field( - ..., - description='A unique identifier associated with this error. 
Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', - examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], - min_length=1, - ) - name: str = Field( - ..., - description='Short-hand name for an error, useful for discriminating between errors with the same status code.', - examples=['bad_request'], - min_length=1, - ) - errors: List[str] = Field( - ..., - description='One or more error messages indicating what went wrong.', - examples=[['some-field: is required']], - min_length=1, - ) - - -class StabilityImageGenrationUpscaleConservativeResponse413(BaseModel): - id: str = Field( - ..., - description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', - examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], - min_length=1, - ) - name: str = Field( - ..., - description='Short-hand name for an error, useful for discriminating between errors with the same status code.', - examples=['bad_request'], - min_length=1, - ) - errors: List[str] = Field( - ..., - description='One or more error messages indicating what went wrong.', - examples=[['some-field: is required']], - min_length=1, - ) - - -class StabilityImageGenrationUpscaleConservativeResponse422(BaseModel): - id: str = Field( - ..., - description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', - examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], - min_length=1, - ) - name: str = Field( - ..., - description='Short-hand name for an error, useful for discriminating between errors with the same status code.', - examples=['bad_request'], - min_length=1, - ) - errors: List[str] = Field( - ..., - description='One or more error messages indicating what went wrong.', - examples=[['some-field: is required']], - min_length=1, - ) - - -class StabilityImageGenrationUpscaleConservativeResponse429(BaseModel): - id: str = Field( - ..., - description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', - examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], - min_length=1, - ) - name: str = Field( - ..., - description='Short-hand name for an error, useful for discriminating between errors with the same status code.', - examples=['bad_request'], - min_length=1, - ) - errors: List[str] = Field( - ..., - description='One or more error messages indicating what went wrong.', - examples=[['some-field: is required']], - min_length=1, - ) - - -class StabilityImageGenrationUpscaleConservativeResponse500(BaseModel): - id: str = Field( - ..., - description='A unique identifier associated with this error. 
Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', - examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], - min_length=1, - ) - name: str = Field( - ..., - description='Short-hand name for an error, useful for discriminating between errors with the same status code.', - examples=['bad_request'], - min_length=1, - ) - errors: List[str] = Field( - ..., - description='One or more error messages indicating what went wrong.', - examples=[['some-field: is required']], - min_length=1, - ) - - -class StabilityImageGenrationUpscaleCreativeRequest(BaseModel): - image: StrictBytes = Field( - ..., - description='The image you wish to upscale.\n\nSupported Formats:\n- jpeg\n- png\n- webp\n\nValidation Rules:\n- Every side must be at least 64 pixels\n- Total pixel count must be between 4,096 and 1,048,576 pixels', - examples=['./some/image.png'], - ) - prompt: str = Field( - ..., - description="What you wish to see in the output image. A strong, descriptive prompt that clearly defines\nelements, colors, and subjects will lead to better results.\n\nTo control the weight of a given word use the format `(word:weight)`,\nwhere `word` is the word you'd like to control the weight of and `weight`\nis a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`\nwould convey a sky that was blue and green, but more green than blue.", - max_length=10000, - min_length=1, - ) - negative_prompt: Optional[str] = Field( - None, - description='A blurb of text describing what you **do not** wish to see in the output image.\nThis is an advanced feature.', - max_length=10000, - ) - output_format: Optional[OutputFormat4] = Field( - 'png', description='Dictates the `content-type` of the generated image.' - ) - seed: Optional[float] = Field( - 0, - description="A specific value that is used to guide the 'randomness' of the generation. (Omit this parameter or pass `0` to use a random seed.)", - ge=0.0, - le=4294967294.0, - ) - creativity: Optional[float] = Field( - 0.3, - description='Indicates how creative the model should be when upscaling an image.\nHigher values will result in more details being added to the image during upscaling.', - ge=0.1, - le=0.5, - ) - style_preset: Optional[StylePreset] = Field( - None, description='Guides the image model towards a particular style.' - ) - - -class StabilityImageGenrationUpscaleCreativeResponse200(BaseModel): - id: StabilityGenerationID - - -class StabilityImageGenrationUpscaleCreativeResponse400(BaseModel): - id: str = Field( - ..., - description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', - examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], - min_length=1, - ) - name: str = Field( - ..., - description='Short-hand name for an error, useful for discriminating between errors with the same status code.', - examples=['bad_request'], - min_length=1, - ) - errors: List[str] = Field( - ..., - description='One or more error messages indicating what went wrong.', - examples=[['some-field: is required']], - min_length=1, - ) - - -class StabilityImageGenrationUpscaleCreativeResponse413(BaseModel): - id: str = Field( - ..., - description='A unique identifier associated with this error. 
Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', - examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], - min_length=1, - ) - name: str = Field( - ..., - description='Short-hand name for an error, useful for discriminating between errors with the same status code.', - examples=['bad_request'], - min_length=1, - ) - errors: List[str] = Field( - ..., - description='One or more error messages indicating what went wrong.', - examples=[['some-field: is required']], - min_length=1, - ) - - -class StabilityImageGenrationUpscaleCreativeResponse422(BaseModel): - id: str = Field( - ..., - description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', - examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], - min_length=1, - ) - name: str = Field( - ..., - description='Short-hand name for an error, useful for discriminating between errors with the same status code.', - examples=['bad_request'], - min_length=1, - ) - errors: List[str] = Field( - ..., - description='One or more error messages indicating what went wrong.', - examples=[['some-field: is required']], - min_length=1, - ) - - -class StabilityImageGenrationUpscaleCreativeResponse429(BaseModel): - id: str = Field( - ..., - description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', - examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], - min_length=1, - ) - name: str = Field( - ..., - description='Short-hand name for an error, useful for discriminating between errors with the same status code.', - examples=['bad_request'], - min_length=1, - ) - errors: List[str] = Field( - ..., - description='One or more error messages indicating what went wrong.', - examples=[['some-field: is required']], - min_length=1, - ) - - -class StabilityImageGenrationUpscaleCreativeResponse500(BaseModel): - id: str = Field( - ..., - description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', - examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], - min_length=1, - ) - name: str = Field( - ..., - description='Short-hand name for an error, useful for discriminating between errors with the same status code.', - examples=['bad_request'], - min_length=1, - ) - errors: List[str] = Field( - ..., - description='One or more error messages indicating what went wrong.', - examples=[['some-field: is required']], - min_length=1, - ) - - -class StabilityImageGenrationUpscaleFastRequest(BaseModel): - image: StrictBytes = Field( - ..., - description='The image you wish to upscale.\n\nSupported Formats:\n- jpeg\n- png\n- webp\n\nValidation Rules:\n- Width must be between 32 and 1,536 pixels\n- Height must be between 32 and 1,536 pixels\n- Total pixel count must be between 1,024 and 1,048,576 pixels', - examples=['./some/image.png'], - ) - output_format: Optional[OutputFormat4] = Field( - 'png', description='Dictates the `content-type` of the generated image.' 
- ) - - -class StabilityImageGenrationUpscaleFastResponse200(BaseModel): - image: str = Field( - ..., - description='The generated image, encoded to base64.', - examples=['AAAAIGZ0eXBpc29tAAACAGlzb21pc28yYXZjMW1...'], - ) - seed: Optional[float] = Field( - 0, - description='The seed used as random noise for this generation.', - examples=[343940597], - ge=0.0, - le=4294967294.0, - ) - finish_reason: FinishReason = Field( - ..., - description='The reason the generation finished.\n\n- `SUCCESS` = successful generation.\n- `CONTENT_FILTERED` = successful generation, however the output violated our content moderation\npolicy and has been blurred as a result.', - examples=['SUCCESS'], - ) - - -class StabilityImageGenrationUpscaleFastResponse400(BaseModel): - id: str = Field( - ..., - description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', - examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], - min_length=1, - ) - name: str = Field( - ..., - description='Short-hand name for an error, useful for discriminating between errors with the same status code.', - examples=['bad_request'], - min_length=1, - ) - errors: List[str] = Field( - ..., - description='One or more error messages indicating what went wrong.', - examples=[['some-field: is required']], - min_length=1, - ) - - -class StabilityImageGenrationUpscaleFastResponse413(BaseModel): - id: str = Field( - ..., - description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', - examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], - min_length=1, - ) - name: str = Field( - ..., - description='Short-hand name for an error, useful for discriminating between errors with the same status code.', - examples=['bad_request'], - min_length=1, - ) - errors: List[str] = Field( - ..., - description='One or more error messages indicating what went wrong.', - examples=[['some-field: is required']], - min_length=1, - ) - - -class StabilityImageGenrationUpscaleFastResponse422(BaseModel): - id: str = Field( - ..., - description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', - examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], - min_length=1, - ) - name: str = Field( - ..., - description='Short-hand name for an error, useful for discriminating between errors with the same status code.', - examples=['bad_request'], - min_length=1, - ) - errors: List[str] = Field( - ..., - description='One or more error messages indicating what went wrong.', - examples=[['some-field: is required']], - min_length=1, - ) - - -class StabilityImageGenrationUpscaleFastResponse429(BaseModel): - id: str = Field( - ..., - description='A unique identifier associated with this error. 
Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', - examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], - min_length=1, - ) - name: str = Field( - ..., - description='Short-hand name for an error, useful for discriminating between errors with the same status code.', - examples=['bad_request'], - min_length=1, - ) - errors: List[str] = Field( - ..., - description='One or more error messages indicating what went wrong.', - examples=[['some-field: is required']], - min_length=1, - ) - - -class StabilityImageGenrationUpscaleFastResponse500(BaseModel): - id: str = Field( - ..., - description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', - examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], - min_length=1, - ) - name: str = Field( - ..., - description='Short-hand name for an error, useful for discriminating between errors with the same status code.', - examples=['bad_request'], - min_length=1, - ) - errors: List[str] = Field( - ..., - description='One or more error messages indicating what went wrong.', - examples=[['some-field: is required']], - min_length=1, - ) - - -class ActionJobResult(BaseModel): - id: Optional[UUID] = Field(None, description='Unique identifier for the job result') - workflow_name: Optional[str] = Field(None, description='Name of the workflow') - operating_system: Optional[str] = Field(None, description='Operating system used') - python_version: Optional[str] = Field(None, description='PyTorch version used') - pytorch_version: Optional[str] = Field(None, description='PyTorch version used') - action_run_id: Optional[str] = Field( - None, description='Identifier of the run this result belongs to' - ) - action_job_id: Optional[str] = Field( - None, description='Identifier of the job this result belongs to' - ) - cuda_version: Optional[str] = Field(None, description='CUDA version used') - branch_name: Optional[str] = Field( - None, description='Name of the relevant git branch' - ) - commit_hash: Optional[str] = Field(None, description='The hash of the commit') - commit_id: Optional[str] = Field(None, description='The ID of the commit') - commit_time: Optional[int] = Field( - None, description='The Unix timestamp when the commit was made' - ) - commit_message: Optional[str] = Field(None, description='The message of the commit') - comfy_run_flags: Optional[str] = Field( - None, description='The comfy run flags. E.g. `--low-vram`' - ) - git_repo: Optional[str] = Field(None, description='The repository name') - pr_number: Optional[str] = Field(None, description='The pull request number') - start_time: Optional[int] = Field( - None, description='The start time of the job as a Unix timestamp.' - ) - end_time: Optional[int] = Field( - None, description='The end time of the job as a Unix timestamp.' - ) - avg_vram: Optional[int] = Field( - None, description='The average VRAM used by the job' - ) - peak_vram: Optional[int] = Field(None, description='The peak VRAM used by the job') - job_trigger_user: Optional[str] = Field( - None, description='The user who triggered the job.' 
-    )
-    author: Optional[str] = Field(None, description='The author of the commit')
-    machine_stats: Optional[MachineStats] = None
-    status: Optional[WorkflowRunStatus] = None
-    storage_file: Optional[StorageFile] = None
-
-
-class Publisher(BaseModel):
-    name: Optional[str] = None
+class StabilityGetResultResponse202(BaseModel):
     id: Optional[str] = Field(
+        None, description='The ID of the generation result.', examples=['1234567890']
+    )
+    status: Optional[Status9] = None
+
+
+class Type20(str, Enum):
+    json_schema = 'json_schema'
+
+
+class TextResponseFormatJsonSchema(BaseModel):
+    description: Optional[str] = Field(
         None,
-        description="The unique identifier for the publisher. It's akin to a username. Should be lowercase.",
+        description='A description of what the response format is for, used by the model to\ndetermine how to respond in the format.\n',
     )
-    description: Optional[str] = None
-    website: Optional[str] = None
-    support: Optional[str] = None
-    source_code_repo: Optional[str] = None
-    logo: Optional[str] = Field(None, description="URL to the publisher's logo.")
-    createdAt: Optional[datetime] = Field(
-        None, description='The date and time the publisher was created.'
+    name: str = Field(
+        ...,
+        description='The name of the response format. Must be a-z, A-Z, 0-9, or contain\nunderscores and dashes, with a maximum length of 64.\n',
     )
-    members: Optional[List[PublisherMember]] = Field(
-        None, description='A list of members in the publisher.'
+    schema_: ResponseFormatJsonSchemaSchema = Field(..., alias='schema')
+    strict: Optional[bool] = Field(
+        False,
+        description='Whether to enable strict schema adherence when generating the output.\nIf set to true, the model will always follow the exact schema defined\nin the `schema` field. Only a subset of JSON Schema is supported when\n`strict` is `true`. To learn more, read the [Structured Outputs\nguide](/docs/guides/structured-outputs).\n',
     )
-    status: Optional[PublisherStatus] = Field(
-        None, description='The status of the publisher.'
+    type: Type20 = Field(
+        ...,
+        description='The type of response format being defined. Always `json_schema`.',
     )


-class NodeVersion(BaseModel):
-    id: Optional[str] = None
-    version: Optional[str] = Field(
+class Type21(str, Enum):
+    function = 'function'
+
+
+class ToolChoiceFunction(BaseModel):
+    name: str = Field(..., description='The name of the function to call.')
+    type: Type21 = Field(
+        ..., description='For function calling, the type is always `function`.'
+    )
+
+
+class ToolChoiceOptions(str, Enum):
+    none = 'none'
+    auto = 'auto'
+    required = 'required'
+
+
+class Type22(str, Enum):
+    file_search = 'file_search'
+    web_search_preview = 'web_search_preview'
+    computer_use_preview = 'computer_use_preview'
+    web_search_preview_2025_03_11 = 'web_search_preview_2025_03_11'
+
+
+class ToolChoiceTypes(BaseModel):
+    type: Type22 = Field(
+        ...,
+        description='The type of hosted tool the model should use. Learn more about\n[built-in tools](/docs/guides/tools).\n\nAllowed values are:\n- `file_search`\n- `web_search_preview`\n- `computer_use_preview`\n',
+    )
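
The `TextResponseFormatJsonSchema` model added above mirrors OpenAI's structured-output response format. As a hedged sketch, not part of this diff, it could be exercised as follows, assuming the generated models are importable from `comfy_api_nodes.apis` (path assumed) and that `ResponseFormatJsonSchemaSchema` coerces from a plain JSON-schema dict:

```python
# Hedged sketch: validate a structured-output response format payload.
# Assumptions: the import path, and that ResponseFormatJsonSchemaSchema
# (defined elsewhere in this module) accepts a plain dict.
from comfy_api_nodes.apis import TextResponseFormatJsonSchema

payload = {
    'type': 'json_schema',
    'name': 'caption_schema',
    'strict': True,
    # Supplied under the 'schema' alias; the model attribute itself is `schema_`.
    'schema': {
        'type': 'object',
        'properties': {'caption': {'type': 'string'}},
        'required': ['caption'],
    },
}

fmt = TextResponseFormatJsonSchema.model_validate(payload)
print(fmt.model_dump(by_alias=True, exclude_none=True))
```
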
+
+
+class TripoAnimation(str, Enum):
+    preset_idle = 'preset:idle'
+    preset_walk = 'preset:walk'
+    preset_climb = 'preset:climb'
+    preset_jump = 'preset:jump'
+    preset_run = 'preset:run'
+    preset_slash = 'preset:slash'
+    preset_shoot = 'preset:shoot'
+    preset_hurt = 'preset:hurt'
+    preset_fall = 'preset:fall'
+    preset_turn = 'preset:turn'
+
+
+class TripoBalance(BaseModel):
+    balance: float
+    frozen: float
+
+
+class TripoConvertFormat(str, Enum):
+    GLTF = 'GLTF'
+    USDZ = 'USDZ'
+    FBX = 'FBX'
+    OBJ = 'OBJ'
+    STL = 'STL'
+    field_3MF = '3MF'
+
+
+class Code(int, Enum):
+    integer_1001 = 1001
+    integer_2000 = 2000
+    integer_2001 = 2001
+    integer_2002 = 2002
+    integer_2003 = 2003
+    integer_2004 = 2004
+    integer_2006 = 2006
+    integer_2007 = 2007
+    integer_2008 = 2008
+    integer_2010 = 2010
+
+
+class TripoErrorResponse(BaseModel):
+    code: Code
+    message: str
+    suggestion: str
+
+
+class TripoImageToModel(str, Enum):
+    image_to_model = 'image_to_model'
+
+
+class TripoModelStyle(str, Enum):
+    person_person2cartoon = 'person:person2cartoon'
+    animal_venom = 'animal:venom'
+    object_clay = 'object:clay'
+    object_steampunk = 'object:steampunk'
+    object_christmas = 'object:christmas'
+    object_barbie = 'object:barbie'
+    gold = 'gold'
+    ancient_bronze = 'ancient_bronze'
+
+
+class TripoModelVersion(str, Enum):
+    V2_5 = 'v2.5-20250123'
+    V2_0 = 'v2.0-20240919'
+    V1_4 = 'v1.4-20240625'
+
+
+class TripoMultiviewMode(str, Enum):
+    LEFT = 'LEFT'
+    RIGHT = 'RIGHT'
+
+
+class TripoMultiviewToModel(str, Enum):
+    multiview_to_model = 'multiview_to_model'
+
+
+class TripoOrientation(str, Enum):
+    align_image = 'align_image'
+    default = 'default'
+
+
+class TripoResponseSuccessCode(RootModel[int]):
+    root: int = Field(
+        ...,
+        description='Standard success code for Tripo API responses. Typically 0 for success.',
+        examples=[0],
+    )
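
The Tripo enums above feed the task models defined just below (`TripoSuccessTask`, `TripoTask`, `Status10`). A minimal polling sketch follows; the HTTP client, base URL, and response envelope are assumptions, and only the model shapes come from this diff:

```python
# Hedged sketch of the Tripo submit-then-poll lifecycle. The base URL, headers,
# and use of `requests` are assumptions; TripoSuccessTask/TripoTask are the
# generated models defined below.
import time

import requests

from comfy_api_nodes.apis import TripoSuccessTask, TripoTask  # path assumed

BASE = 'https://api.tripo3d.ai/v2/openapi'  # assumed endpoint


def wait_for_task(headers: dict, payload: dict, poll_seconds: float = 2.0) -> TripoTask:
    # A successful submission parses into TripoSuccessTask, whose Data8
    # payload carries the task_id that is "used for getTask".
    created = TripoSuccessTask.model_validate(
        requests.post(f'{BASE}/task', json=payload, headers=headers).json()
    )
    while True:
        body = requests.get(f'{BASE}/task/{created.data.task_id}', headers=headers).json()
        task = TripoTask.model_validate(body['data'])  # 'data' envelope assumed
        # Status10 treats everything except queued/running as terminal.
        if task.status not in ('queued', 'running'):
            return task
        time.sleep(poll_seconds)
```
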
+
+
+class TripoSpec(str, Enum):
+    mixamo = 'mixamo'
+    tripo = 'tripo'
+
+
+class TripoStandardFormat(str, Enum):
+    glb = 'glb'
+    fbx = 'fbx'
+
+
+class TripoStylizeOptions(str, Enum):
+    lego = 'lego'
+    voxel = 'voxel'
+    voronoi = 'voronoi'
+    minecraft = 'minecraft'
+
+
+class Code1(int, Enum):
+    integer_0 = 0
+
+
+class Data8(BaseModel):
+    task_id: str = Field(..., description='used for getTask')
+
+
+class TripoSuccessTask(BaseModel):
+    code: Code1
+    data: Data8
+
+
+class Topology(str, Enum):
+    bip = 'bip'
+    quad = 'quad'
+
+
+class Output(BaseModel):
+    base_model: Optional[str] = None
+    model: Optional[str] = None
+    pbr_model: Optional[str] = None
+    rendered_image: Optional[str] = None
+    riggable: Optional[bool] = None
+    topology: Optional[Topology] = None
+
+
+class Status10(str, Enum):
+    queued = 'queued'
+    running = 'running'
+    success = 'success'
+    failed = 'failed'
+    cancelled = 'cancelled'
+    unknown = 'unknown'
+    banned = 'banned'
+    expired = 'expired'
+
+
+class TripoTask(BaseModel):
+    create_time: int
+    input: Dict[str, Any]
+    output: Output
+    progress: int = Field(..., ge=0, le=100)
+    status: Status10
+    task_id: str
+    type: str
+
+
+class TripoTextToModel(str, Enum):
+    text_to_model = 'text_to_model'
+
+
+class TripoTextureAlignment(str, Enum):
+    original_image = 'original_image'
+    geometry = 'geometry'
+
+
+class TripoTextureFormat(str, Enum):
+    BMP = 'BMP'
+    DPX = 'DPX'
+    HDR = 'HDR'
+    JPEG = 'JPEG'
+    OPEN_EXR = 'OPEN_EXR'
+    PNG = 'PNG'
+    TARGA = 'TARGA'
+    TIFF = 'TIFF'
+    WEBP = 'WEBP'
+
+
+class TripoTextureQuality(str, Enum):
+    standard = 'standard'
+    detailed = 'detailed'
+
+
+class TripoTopology(str, Enum):
+    bip = 'bip'
+    quad = 'quad'
+
+
+class TripoTypeAnimatePrerigcheck(str, Enum):
+    animate_prerigcheck = 'animate_prerigcheck'
+
+
+class TripoTypeAnimateRetarget(str, Enum):
+    animate_retarget = 'animate_retarget'
+
+
+class TripoTypeAnimateRig(str, Enum):
+    animate_rig = 'animate_rig'
+
+
+class TripoTypeConvertModel(str, Enum):
+    convert_model = 'convert_model'
+
+
+class TripoTypeRefineModel(str, Enum):
+    refine_model = 'refine_model'
+
+
+class TripoTypeStylizeModel(str, Enum):
+    stylize_model = 'stylize_model'
+
+
+class TripoTypeTextureModel(str, Enum):
+    texture_model = 'texture_model'
+
+
+class Veo2GenVidPollRequest(BaseModel):
+    operationName: str = Field(
+        ...,
+        description='Full operation name (from predict response)',
+        examples=[
+            'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID'
+        ],
+    )
+
+
+class Error(BaseModel):
+    code: Optional[int] = Field(None, description='Error code')
+    message: Optional[str] = Field(None, description='Error message')
+
+
+class Video(BaseModel):
+    bytesBase64Encoded: Optional[str] = Field(
+        None, description='Base64-encoded video content'
+    )
+    gcsUri: Optional[str] = Field(None, description='Cloud Storage URI of the video')
+    mimeType: Optional[str] = Field(None, description='Video MIME type')
+
+
+class Response(BaseModel):
+    field_type: Optional[str] = Field(
        None,
-        description='The version identifier, following semantic versioning. Must be unique for the node.',
+        alias='@type',
+        examples=[
+            'type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse'
+        ],
     )
-    createdAt: Optional[datetime] = Field(
-        None, description='The date and time the version was created.'
+ raiMediaFilteredCount: Optional[int] = Field( + None, description='Count of media filtered by responsible AI policies' ) - changelog: Optional[str] = Field( - None, description='Summary of changes made in this version' + raiMediaFilteredReasons: Optional[List[str]] = Field( + None, description='Reasons why media was filtered by responsible AI policies' ) - dependencies: Optional[List[str]] = Field( - None, description='A list of pip dependencies required by the node.' + videos: Optional[List[Video]] = None + + +class Veo2GenVidPollResponse(BaseModel): + done: Optional[bool] = None + error: Optional[Error] = Field( + None, description='Error details if operation failed' ) - downloadUrl: Optional[str] = Field( - None, description='[Output Only] URL to download this version of the node' - ) - deprecated: Optional[bool] = Field( - None, description='Indicates if this version is deprecated.' - ) - status: Optional[NodeVersionStatus] = Field( - None, description='The status of the node version.' - ) - status_reason: Optional[str] = Field( - None, description='The reason for the status change.' - ) - node_id: Optional[str] = Field( - None, description='The unique identifier of the node.' - ) - comfy_node_extract_status: Optional[str] = Field( - None, description='The status of comfy node extraction process.' + name: Optional[str] = None + response: Optional[Response] = Field( + None, description='The actual prediction response if done is true' ) -class IdeogramV3Request(BaseModel): - prompt: str = Field(..., description='The text prompt for image generation') - seed: Optional[int] = Field( - None, description='Seed value for reproducible generation' +class Image(BaseModel): + bytesBase64Encoded: str + gcsUri: Optional[str] = None + mimeType: Optional[str] = None + + +class Image1(BaseModel): + bytesBase64Encoded: Optional[str] = None + gcsUri: str + mimeType: Optional[str] = None + + +class Instance(BaseModel): + image: Optional[Union[Image, Image1]] = Field( + None, description='Optional image to guide video generation' ) - resolution: Optional[str] = Field( - None, description='Image resolution in format WxH', examples=['1280x800'] + prompt: str = Field(..., description='Text description of the video') + + +class PersonGeneration1(str, Enum): + ALLOW = 'ALLOW' + BLOCK = 'BLOCK' + + +class Parameters(BaseModel): + aspectRatio: Optional[str] = Field(None, examples=['16:9']) + durationSeconds: Optional[int] = None + enhancePrompt: Optional[bool] = None + negativePrompt: Optional[str] = None + personGeneration: Optional[PersonGeneration1] = None + sampleCount: Optional[int] = None + seed: Optional[int] = None + storageUri: Optional[str] = Field( + None, description='Optional Cloud Storage URI to upload the video' ) - aspect_ratio: Optional[str] = Field( - None, description='Aspect ratio in format WxH', examples=['1x3'] + + +class Veo2GenVidRequest(BaseModel): + instances: Optional[List[Instance]] = None + parameters: Optional[Parameters] = None + + +class Veo2GenVidResponse(BaseModel): + name: str = Field( + ..., + description='Operation resource name', + examples=[ + 'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/a1b07c8e-7b5a-4aba-bb34-3e1ccb8afcc8' + ], ) - rendering_speed: RenderingSpeed - magic_prompt: Optional[MagicPrompt] = Field( - None, description='Whether to enable magic prompt enhancement' + + +class SearchContextSize(str, Enum): + low = 'low' + medium = 'medium' + high = 'high' + + +class Type23(str, Enum): + web_search_preview = 
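'web_search_preview'
+    web_search_preview_2025_03_11 = 'web_search_preview_2025_03_11'

The Veo 2 models defined above describe a submit-then-poll operation flow: `Veo2GenVidResponse` returns an operation name, which `Veo2GenVidPollRequest` echoes back until `done`. A hedged sketch, with a hypothetical `call_api` transport standing in for the real client; only the request/response shapes come from this diff:

```python
# Hedged sketch of the Veo 2 flow implied by these models. `call_api` is a
# hypothetical helper returning parsed JSON; the import path is assumed.
import time

from comfy_api_nodes.apis import (
    Instance,
    Parameters,
    Veo2GenVidPollRequest,
    Veo2GenVidPollResponse,
    Veo2GenVidRequest,
    Veo2GenVidResponse,
)


def generate_video(call_api, prompt: str) -> Veo2GenVidPollResponse:
    request = Veo2GenVidRequest(
        instances=[Instance(prompt=prompt)],
        parameters=Parameters(aspectRatio='16:9', durationSeconds=5, sampleCount=1),
    )
    op = Veo2GenVidResponse.model_validate(call_api('predict', request))
    poll = Veo2GenVidPollRequest(operationName=op.name)
    while True:
        status = Veo2GenVidPollResponse.model_validate(call_api('poll', poll))
        if status.done:
            if status.error is not None:
                raise RuntimeError(f'{status.error.code}: {status.error.message}')
            return status
        time.sleep(2.0)
```
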
+
+
+class WebSearchPreviewTool(BaseModel):
+    search_context_size: Optional[SearchContextSize] = Field(
+        None,
+        description='High level guidance for the amount of context window space to use for the search. One of `low`, `medium`, or `high`. `medium` is the default.',
+    )
+    type: Literal['WebSearchPreviewTool'] = Field(
+        ...,
+        description='The type of the web search tool. One of `web_search_preview` or `web_search_preview_2025_03_11`.',
+    )
+
+
+class Status11(str, Enum):
+    in_progress = 'in_progress'
+    searching = 'searching'
+    completed = 'completed'
+    failed = 'failed'
+
+
+class Type24(str, Enum):
+    web_search_call = 'web_search_call'
+
+
+class WebSearchToolCall(BaseModel):
+    id: str = Field(..., description='The unique ID of the web search tool call.\n')
+    status: Status11 = Field(
+        ..., description='The status of the web search tool call.\n'
+    )
+    type: Type24 = Field(
+        ...,
+        description='The type of the web search tool call. Always `web_search_call`.\n',
+    )
+
+
+class CreateModelResponseProperties(ModelResponseProperties):
+    pass
+
+
+class GeminiInlineData(BaseModel):
+    data: Optional[str] = Field(
+        None,
+        description='The base64 encoding of the image, PDF, or video to include inline in the prompt. When including media inline, you must also specify the media type (mimeType) of the data. Size limit: 20MB\n',
+    )
+    mimeType: Optional[GeminiMimeType] = None
+
+
+class GeminiPart(BaseModel):
+    inlineData: Optional[GeminiInlineData] = None
+    text: Optional[str] = Field(
+        None,
+        description='A text prompt or code snippet.',
+        examples=['Write a story about a robot learning to paint'],
+    )
+
+
+class GeminiPromptFeedback(BaseModel):
+    blockReason: Optional[str] = None
+    blockReasonMessage: Optional[str] = None
+    safetyRatings: Optional[List[GeminiSafetyRating]] = None
+
+
+class GeminiSafetySetting(BaseModel):
+    category: GeminiSafetyCategory
+    threshold: GeminiSafetyThreshold
+
+
+class GeminiSystemInstructionContent(BaseModel):
+    parts: List[GeminiTextPart] = Field(
+        ...,
+        description='A list of ordered parts that make up a single message. Different parts may have different IANA MIME types. For limits on the inputs, such as the maximum number of tokens or the number of images, see the model specifications on the Google models page.\n',
+    )
+    role: Role1 = Field(
+        ...,
+        description='The identity of the entity that creates the message. The following values are supported: user: This indicates that the message is sent by a real person, typically a user-generated message. model: This indicates that the message is generated by the model. The model value is used to insert messages from the model into the conversation during multi-turn conversations. 
For non-multi-turn conversations, this field can be left blank or unset.\n', + examples=['user'], ) class IdeogramV3EditRequest(BaseModel): + color_palette: Optional[IdeogramColorPalette] = None image: Optional[StrictBytes] = Field( None, description='The image being edited (max size 10MB); only JPEG, WebP and PNG formats are supported at this time.', ) - mask: Optional[StrictBytes] = Field( - None, - description='A black and white image of the same size as the image being edited (max size 10MB). Black regions in the mask should match up with the regions of the image that you would like to edit; only JPEG, WebP and PNG formats are supported at this time.', - ) - prompt: str = Field( - ..., description='The prompt used to describe the edited result.' - ) magic_prompt: Optional[str] = Field( None, description='Determine if MagicPrompt should be used in generating the request or not.', ) + mask: Optional[StrictBytes] = Field( + None, + description='A black and white image of the same size as the image being edited (max size 10MB). Black regions in the mask should match up with the regions of the image that you would like to edit; only JPEG, WebP and PNG formats are supported at this time.', + ) num_images: Optional[int] = Field( None, description='The number of images to generate.' ) - seed: Optional[int] = Field( - None, description='Random seed. Set for reproducible generation.' + prompt: str = Field( + ..., description='The prompt used to describe the edited result.' ) rendering_speed: RenderingSpeed - color_palette: Optional[IdeogramColorPalette] = Field( - None, - description='A color palette for generation, must EITHER be specified via one of the presets (name) or explicitly via hexadecimal representations of the color with optional weights (members). Not supported by V_1, V_1_TURBO, V_2A and V_2A_TURBO models.', + seed: Optional[int] = Field( + None, description='Random seed. Set for reproducible generation.' 
) style_codes: Optional[List[StyleCode]] = Field( None, @@ -3166,34 +2723,102 @@ class IdeogramV3EditRequest(BaseModel): ) -class KlingCameraControl(BaseModel): - type: Optional[KlingCameraControlType] = None - config: Optional[KlingCameraConfig] = None - - -class KlingText2VideoRequest(BaseModel): - model_name: Optional[KlingVideoGenModelName] = 'kling-v2-master' - prompt: Optional[str] = Field( - None, description='Positive text prompt', max_length=2500 +class IdeogramV3Request(BaseModel): + aspect_ratio: Optional[str] = Field( + None, description='Aspect ratio in format WxH', examples=['1x3'] + ) + color_palette: Optional[ColorPalette] = None + magic_prompt: Optional[MagicPrompt2] = Field( + None, description='Whether to enable magic prompt enhancement' ) negative_prompt: Optional[str] = Field( - None, description='Negative text prompt', max_length=2500 + None, description='Text prompt specifying what to avoid in the generation' ) - cfg_scale: Optional[KlingVideoGenCfgScale] = Field( - default_factory=lambda: KlingVideoGenCfgScale.model_validate(0.5) + num_images: Optional[int] = Field( + None, description='Number of images to generate', ge=1 ) + prompt: str = Field(..., description='The text prompt for image generation') + rendering_speed: RenderingSpeed + resolution: Optional[str] = Field( + None, description='Image resolution in format WxH', examples=['1280x800'] + ) + seed: Optional[int] = Field( + None, description='Seed value for reproducible generation' + ) + style_codes: Optional[List[StyleCode]] = Field( + None, description='Array of style codes in hexadecimal format' + ) + style_reference_images: Optional[List[str]] = Field( + None, description='Array of reference image URLs or identifiers' + ) + style_type: Optional[StyleType1] = Field( + None, description='The type of style to apply' + ) + + +class ImagenGenerateImageResponse(BaseModel): + predictions: Optional[List[ImagenImagePrediction]] = None + + +class ImagenImageGenerationParameters(BaseModel): + addWatermark: Optional[bool] = None + aspectRatio: Optional[AspectRatio] = None + enhancePrompt: Optional[bool] = None + includeRaiReason: Optional[bool] = None + includeSafetyAttributes: Optional[bool] = None + outputOptions: Optional[ImagenOutputOptions] = None + personGeneration: Optional[PersonGeneration] = None + safetySetting: Optional[SafetySetting] = None + sampleCount: Optional[int] = Field(None, ge=1, le=4) + seed: Optional[int] = None + storageUri: Optional[AnyUrl] = None + + +class InputContent( + RootModel[Union[InputTextContent, InputImageContent, InputFileContent]] +): + root: Union[InputTextContent, InputImageContent, InputFileContent] + + +class InputMessageContentList(RootModel[List[InputContent]]): + root: List[InputContent] = Field( + ..., + description='A list of one or many input items to the model, containing different content \ntypes.\n', + title='Input item content list', + ) + + +class KlingCameraControl(BaseModel): + config: Optional[KlingCameraConfig] = None + type: Optional[KlingCameraControlType] = None + + +class KlingDualCharacterEffectInput(BaseModel): + duration: KlingVideoGenDuration + images: KlingDualCharacterImages mode: Optional[KlingVideoGenMode] = 'std' - camera_control: Optional[KlingCameraControl] = None - aspect_ratio: Optional[KlingVideoGenAspectRatio] = '16:9' - duration: Optional[KlingVideoGenDuration] = '5' - callback_url: Optional[AnyUrl] = Field( - None, description='The callback notification address' - ) - external_task_id: Optional[str] = Field(None, description='Customized 
Task ID') + model_name: Optional[KlingCharacterEffectModelName] = 'kling-v1' class KlingImage2VideoRequest(BaseModel): - model_name: Optional[KlingVideoGenModelName] = 'kling-v2-master' + aspect_ratio: Optional[KlingVideoGenAspectRatio] = '16:9' + callback_url: Optional[AnyUrl] = Field( + None, + description='The callback notification address. Server will notify when the task status changes.', + ) + camera_control: Optional[KlingCameraControl] = None + cfg_scale: Optional[KlingVideoGenCfgScale] = Field( + default_factory=lambda: KlingVideoGenCfgScale.model_validate(0.5) + ) + duration: Optional[KlingVideoGenDuration] = '5' + dynamic_masks: Optional[List[DynamicMask]] = Field( + None, + description='Dynamic Brush Configuration List (up to 6 groups). For 5-second videos, trajectory length must not exceed 77 coordinates.', + ) + external_task_id: Optional[str] = Field( + None, + description='Customized Task ID. Must be unique within a single user account.', + ) image: Optional[str] = Field( None, description='Reference Image - URL or Base64 encoded string, cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1. Base64 should not include data:image prefix.', @@ -3202,35 +2827,168 @@ class KlingImage2VideoRequest(BaseModel): None, description='Reference Image - End frame control. URL or Base64 encoded string, cannot exceed 10MB, resolution not less than 300*300px. Base64 should not include data:image prefix.', ) - prompt: Optional[str] = Field( - None, description='Positive text prompt', max_length=2500 - ) + mode: Optional[KlingVideoGenMode] = 'std' + model_name: Optional[KlingVideoGenModelName] = 'kling-v2-master' negative_prompt: Optional[str] = Field( None, description='Negative text prompt', max_length=2500 ) - cfg_scale: Optional[KlingVideoGenCfgScale] = Field( - default_factory=lambda: KlingVideoGenCfgScale.model_validate(0.5) + prompt: Optional[str] = Field( + None, description='Positive text prompt', max_length=2500 ) - mode: Optional[KlingVideoGenMode] = 'std' static_mask: Optional[str] = Field( None, description='Static Brush Application Area (Mask image created by users using the motion brush). 
The aspect ratio must match the input image.', ) - dynamic_masks: Optional[List[DynamicMask]] = Field( + + +class TaskResult(BaseModel): + videos: Optional[List[KlingVideoResult]] = None + + +class Data(BaseModel): + created_at: Optional[int] = Field(None, description='Task creation time') + task_id: Optional[str] = Field(None, description='Task ID') + task_info: Optional[TaskInfo] = None + task_result: Optional[TaskResult] = None + task_status: Optional[KlingTaskStatus] = None + updated_at: Optional[int] = Field(None, description='Task update time') + + +class KlingImage2VideoResponse(BaseModel): + code: Optional[int] = Field(None, description='Error code') + data: Optional[Data] = None + message: Optional[str] = Field(None, description='Error message') + request_id: Optional[str] = Field(None, description='Request ID') + + +class TaskResult1(BaseModel): + images: Optional[List[KlingImageResult]] = None + + +class Data1(BaseModel): + created_at: Optional[int] = Field(None, description='Task creation time') + task_id: Optional[str] = Field(None, description='Task ID') + task_result: Optional[TaskResult1] = None + task_status: Optional[KlingTaskStatus] = None + task_status_msg: Optional[str] = Field(None, description='Task status information') + updated_at: Optional[int] = Field(None, description='Task update time') + + +class KlingImageGenerationsResponse(BaseModel): + code: Optional[int] = Field(None, description='Error code') + data: Optional[Data1] = None + message: Optional[str] = Field(None, description='Error message') + request_id: Optional[str] = Field(None, description='Request ID') + + +class KlingLipSyncInputObject(BaseModel): + audio_file: Optional[str] = Field( None, - description='Dynamic Brush Configuration List (up to 6 groups). For 5-second videos, trajectory length must not exceed 77 coordinates.', + description='Local Path of Audio File. Supported formats: .mp3/.wav/.m4a/.aac, maximum file size of 5MB. Base64 code.', ) - camera_control: Optional[KlingCameraControl] = None - aspect_ratio: Optional[KlingVideoGenAspectRatio] = '16:9' - duration: Optional[KlingVideoGenDuration] = '5' + audio_type: Optional[KlingAudioUploadType] = None + audio_url: Optional[str] = Field( + None, + description='Audio File Download URL. Supported formats: .mp3/.wav/.m4a/.aac, maximum file size of 5MB.', + ) + mode: KlingLipSyncMode + text: Optional[str] = Field( + None, + description='Text Content for Lip-Sync Video Generation. Required when mode is text2video. Maximum length is 120 characters.', + ) + video_id: Optional[str] = Field( + None, + description='The ID of the video generated by Kling AI. Only supports 5-second and 10-second videos generated within the last 30 days.', + ) + video_url: Optional[str] = Field( + None, + description='Get link for uploaded video. Video files support .mp4/.mov, file size does not exceed 100MB, video length between 2-10s.', + ) + voice_id: Optional[str] = Field( + None, + description='Voice ID. Required when mode is text2video. The system offers a variety of voice options to choose from.', + ) + voice_language: Optional[KlingLipSyncVoiceLanguage] = 'en' + voice_speed: Optional[float] = Field( + 1, + description='Speech Rate. Valid range: 0.8~2.0, accurate to one decimal place.', + ge=0.8, + le=2.0, + ) + + +class KlingLipSyncRequest(BaseModel): callback_url: Optional[AnyUrl] = Field( None, description='The callback notification address. 
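Server will notify when the task status changes.',
    )
-    external_task_id: Optional[str] = Field(
-        None,
-        description='Customized Task ID. Must be unique within a single user account.',
+    input: KlingLipSyncInputObject

`KlingLipSyncRequest` wraps a single `KlingLipSyncInputObject`; in `text2video` mode the text and voice fields drive generation, while `video_id`/`video_url` select the footage. A hedged construction sketch follows; the field values are illustrative, and the import path and voice ID are assumptions:

```python
# Hedged sketch: building a text2video lip-sync request from the models above.
from comfy_api_nodes.apis import KlingLipSyncInputObject, KlingLipSyncRequest

request = KlingLipSyncRequest(
    input=KlingLipSyncInputObject(
        mode='text2video',            # KlingLipSyncMode value, coerced from str
        text='Hello from ComfyUI!',   # required in text2video mode, max 120 chars
        voice_id='example-voice',     # hypothetical ID; the API publishes the real list
        voice_language='en',
        voice_speed=1.0,
        video_id='example-video-id',  # a 5s/10s Kling video from the last 30 days
    )
)
print(request.model_dump_json(exclude_none=True))
```
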
+
+
+class TaskResult2(BaseModel):
+    videos: Optional[List[KlingVideoResult]] = None
+
+
+class Data2(BaseModel):
+    created_at: Optional[int] = Field(None, description='Task creation time')
+    task_id: Optional[str] = Field(None, description='Task ID')
+    task_info: Optional[TaskInfo] = None
+    task_result: Optional[TaskResult2] = None
+    task_status: Optional[KlingTaskStatus] = None
+    updated_at: Optional[int] = Field(None, description='Task update time')
+
+
+class KlingLipSyncResponse(BaseModel):
+    code: Optional[int] = Field(None, description='Error code')
+    data: Optional[Data2] = None
+    message: Optional[str] = Field(None, description='Error message')
+    request_id: Optional[str] = Field(None, description='Request ID')
+
+
+class KlingSingleImageEffectInput(BaseModel):
+    duration: KlingSingleImageEffectDuration
+    image: str = Field(
+        ...,
+        description='Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1.',
+    )
+    model_name: KlingSingleImageEffectModelName
+
+
+class KlingText2VideoRequest(BaseModel):
+    aspect_ratio: Optional[KlingVideoGenAspectRatio] = '16:9'
+    callback_url: Optional[AnyUrl] = Field(
+        None, description='The callback notification address'
+    )
+    camera_control: Optional[KlingCameraControl] = None
+    cfg_scale: Optional[KlingVideoGenCfgScale] = Field(
+        default_factory=lambda: KlingVideoGenCfgScale.model_validate(0.5)
+    )
+    duration: Optional[KlingVideoGenDuration] = '5'
+    external_task_id: Optional[str] = Field(None, description='Customized Task ID')
+    mode: Optional[KlingVideoGenMode] = 'std'
+    model_name: Optional[KlingTextToVideoModelName] = 'kling-v1'
+    negative_prompt: Optional[str] = Field(
+        None, description='Negative text prompt', max_length=2500
+    )
+    prompt: Optional[str] = Field(
+        None, description='Positive text prompt', max_length=2500
+    )
+
+
+class Data4(BaseModel):
+    created_at: Optional[int] = Field(None, description='Task creation time')
+    task_id: Optional[str] = Field(None, description='Task ID')
+    task_info: Optional[TaskInfo] = None
+    task_result: Optional[TaskResult2] = None
+    task_status: Optional[KlingTaskStatus] = None
+    updated_at: Optional[int] = Field(None, description='Task update time')
+
+
+class KlingText2VideoResponse(BaseModel):
+    code: Optional[int] = Field(None, description='Error code')
+    data: Optional[Data4] = None
+    message: Optional[str] = Field(None, description='Error message')
+    request_id: Optional[str] = Field(None, description='Request ID')


 class KlingVideoEffectsInput(
@@ -3239,351 +2997,325 @@ class KlingVideoEffectsInput(
     root: Union[KlingSingleImageEffectInput, KlingDualCharacterEffectInput]


-class StripeBillingDetails(BaseModel):
-    address: Optional[StripeAddress] = None
-    email: Optional[str] = None
-    name: Optional[str] = None
-    phone: Optional[str] = None
-    tax_id: Optional[Any] = None
-
-
-class StripePaymentMethodDetails(BaseModel):
-    card: Optional[StripeCardDetails] = None
-    type: Optional[str] = None
-
-
-class BFLFluxProFillInputs(BaseModel):
-    image: str = Field(
-        ...,
-        description='A Base64-encoded string representing the image you wish to modify. 
Can contain alpha mask if desired.', - title='Image', - ) - mask: Optional[str] = Field( +class KlingVideoEffectsRequest(BaseModel): + callback_url: Optional[AnyUrl] = Field( None, - description='A Base64-encoded string representing a mask for the areas you want to modify in the image. The mask should be the same dimensions as the image and in black and white. Black areas (0%) indicate no modification, while white areas (100%) specify areas for inpainting. Optional if you provide an alpha mask in the original image. Validation: The endpoint verifies that the dimensions of the mask match the original image.', - title='Mask', + description='The callback notification address for the result of this task.', + ) + effect_scene: Union[KlingDualCharacterEffectsScene, KlingSingleImageEffectsScene] + external_task_id: Optional[str] = Field( + None, + description='Customized Task ID. Must be unique within a single user account.', + ) + input: KlingVideoEffectsInput + + +class Data5(BaseModel): + created_at: Optional[int] = Field(None, description='Task creation time') + task_id: Optional[str] = Field(None, description='Task ID') + task_info: Optional[TaskInfo] = None + task_result: Optional[TaskResult2] = None + task_status: Optional[KlingTaskStatus] = None + updated_at: Optional[int] = Field(None, description='Task update time') + + +class KlingVideoEffectsResponse(BaseModel): + code: Optional[int] = Field(None, description='Error code') + data: Optional[Data5] = None + message: Optional[str] = Field(None, description='Error message') + request_id: Optional[str] = Field(None, description='Request ID') + + +class KlingVideoExtendRequest(BaseModel): + callback_url: Optional[AnyUrl] = Field( + None, + description='The callback notification address. Server will notify when the task status changes.', + ) + cfg_scale: Optional[KlingVideoGenCfgScale] = Field( + default_factory=lambda: KlingVideoGenCfgScale.model_validate(0.5) + ) + negative_prompt: Optional[str] = Field( + None, + description='Negative text prompt for elements to avoid in the extended video', + max_length=2500, ) prompt: Optional[str] = Field( - '', - description='The description of the changes you want to make. This text guides the inpainting process, allowing you to specify features, styles, or modifications for the masked area.', - examples=['ein fantastisches bild'], - title='Prompt', - ) - steps: Optional[Steps] = Field( - default_factory=lambda: Steps.model_validate(50), - description='Number of steps for the image generation process', - examples=[50], - title='Steps', - ) - prompt_upsampling: Optional[bool] = Field( - False, - description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation', - title='Prompt Upsampling', - ) - seed: Optional[int] = Field( - None, description='Optional seed for reproducibility', title='Seed' - ) - guidance: Optional[Guidance] = Field( - default_factory=lambda: Guidance.model_validate(60), - description='Guidance strength for the image generation process', - title='Guidance', - ) - output_format: Optional[BFLOutputFormat] = Field( - 'jpeg', - description="Output format for the generated image. Can be 'jpeg' or 'png'.", - ) - safety_tolerance: Optional[int] = Field( - 2, - description='Tolerance level for input and output moderation. 
Between 0 and 6, 0 being most strict, 6 being least strict.', - examples=[2], - ge=0, - le=6, - title='Safety Tolerance', - ) - webhook_url: Optional[WebhookUrl] = Field( - None, description='URL to receive webhook notifications', title='Webhook Url' - ) - webhook_secret: Optional[str] = Field( None, - description='Optional secret for webhook signature verification', - title='Webhook Secret', + description='Positive text prompt for guiding the video extension', + max_length=2500, ) - - -class BFLHTTPValidationError(BaseModel): - detail: Optional[List[BFLValidationError]] = Field(None, title='Detail') - - -class BFLFluxProExpandInputs(BaseModel): - image: str = Field( - ..., - description='A Base64-encoded string representing the image you wish to expand.', - title='Image', - ) - top: Optional[Top] = Field( - 0, description='Number of pixels to expand at the top of the image', title='Top' - ) - bottom: Optional[Bottom] = Field( - 0, - description='Number of pixels to expand at the bottom of the image', - title='Bottom', - ) - left: Optional[Left] = Field( - 0, - description='Number of pixels to expand on the left side of the image', - title='Left', - ) - right: Optional[Right] = Field( - 0, - description='Number of pixels to expand on the right side of the image', - title='Right', - ) - prompt: Optional[str] = Field( - '', - description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.', - examples=['ein fantastisches bild'], - title='Prompt', - ) - steps: Optional[Steps] = Field( - default_factory=lambda: Steps.model_validate(50), - description='Number of steps for the image generation process', - examples=[50], - title='Steps', - ) - prompt_upsampling: Optional[bool] = Field( - False, - description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation', - title='Prompt Upsampling', - ) - seed: Optional[int] = Field( - None, description='Optional seed for reproducibility', title='Seed' - ) - guidance: Optional[Guidance] = Field( - default_factory=lambda: Guidance.model_validate(60), - description='Guidance strength for the image generation process', - title='Guidance', - ) - output_format: Optional[BFLOutputFormat] = Field( - 'jpeg', - description="Output format for the generated image. Can be 'jpeg' or 'png'.", - ) - safety_tolerance: Optional[int] = Field( - 2, - description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict.', - examples=[2], - ge=0, - le=6, - title='Safety Tolerance', - ) - webhook_url: Optional[WebhookUrl] = Field( - None, description='URL to receive webhook notifications', title='Webhook Url' - ) - webhook_secret: Optional[str] = Field( + video_id: Optional[str] = Field( None, - description='Optional secret for webhook signature verification', - title='Webhook Secret', + description='The ID of the video to be extended. Supports videos generated by text-to-video, image-to-video, and previous video extension operations. 
Cannot exceed 3 minutes total duration after extension.', ) -class BFLCannyInputs(BaseModel): - prompt: str = Field( - ..., - description='Text prompt for image generation', - examples=['ein fantastisches bild'], - title='Prompt', - ) - control_image: Optional[str] = Field( - None, - description='Base64 encoded image to use as control input if no preprocessed image is provided', - title='Control Image', - ) - preprocessed_image: Optional[str] = Field( - None, - description='Optional pre-processed image that will bypass the control preprocessing step', - title='Preprocessed Image', - ) - canny_low_threshold: Optional[CannyLowThreshold] = Field( - default_factory=lambda: CannyLowThreshold.model_validate(50), - description='Low threshold for Canny edge detection', - title='Canny Low Threshold', - ) - canny_high_threshold: Optional[CannyHighThreshold] = Field( - default_factory=lambda: CannyHighThreshold.model_validate(200), - description='High threshold for Canny edge detection', - title='Canny High Threshold', - ) - prompt_upsampling: Optional[bool] = Field( - False, - description='Whether to perform upsampling on the prompt', - title='Prompt Upsampling', - ) - seed: Optional[int] = Field( - None, - description='Optional seed for reproducibility', - examples=[42], - title='Seed', - ) - steps: Optional[Steps2] = Field( - default_factory=lambda: Steps2.model_validate(50), - description='Number of steps for the image generation process', - title='Steps', - ) - output_format: Optional[BFLOutputFormat] = Field( - 'jpeg', - description="Output format for the generated image. Can be 'jpeg' or 'png'.", - ) - guidance: Optional[Guidance2] = Field( - default_factory=lambda: Guidance2.model_validate(30), - description='Guidance strength for the image generation process', - title='Guidance', - ) - safety_tolerance: Optional[int] = Field( - 2, - description='Tolerance level for input and output moderation. 
Between 0 and 6, 0 being most strict, 6 being least strict.', - ge=0, - le=6, - title='Safety Tolerance', - ) - webhook_url: Optional[WebhookUrl] = Field( - None, description='URL to receive webhook notifications', title='Webhook Url' - ) - webhook_secret: Optional[str] = Field( - None, - description='Optional secret for webhook signature verification', - title='Webhook Secret', - ) +class Data6(BaseModel): + created_at: Optional[int] = Field(None, description='Task creation time') + task_id: Optional[str] = Field(None, description='Task ID') + task_info: Optional[TaskInfo] = None + task_result: Optional[TaskResult2] = None + task_status: Optional[KlingTaskStatus] = None + updated_at: Optional[int] = Field(None, description='Task update time') -class BFLDepthInputs(BaseModel): - prompt: str = Field( - ..., - description='Text prompt for image generation', - examples=['ein fantastisches bild'], - title='Prompt', - ) - control_image: Optional[str] = Field( - None, - description='Base64 encoded image to use as control input', - title='Control Image', - ) - preprocessed_image: Optional[str] = Field( - None, - description='Optional pre-processed image that will bypass the control preprocessing step', - title='Preprocessed Image', - ) - prompt_upsampling: Optional[bool] = Field( - False, - description='Whether to perform upsampling on the prompt', - title='Prompt Upsampling', - ) - seed: Optional[int] = Field( - None, - description='Optional seed for reproducibility', - examples=[42], - title='Seed', - ) - steps: Optional[Steps2] = Field( - default_factory=lambda: Steps2.model_validate(50), - description='Number of steps for the image generation process', - title='Steps', - ) - output_format: Optional[BFLOutputFormat] = Field( - 'jpeg', - description="Output format for the generated image. Can be 'jpeg' or 'png'.", - ) - guidance: Optional[Guidance2] = Field( - default_factory=lambda: Guidance2.model_validate(15), - description='Guidance strength for the image generation process', - title='Guidance', - ) - safety_tolerance: Optional[int] = Field( - 2, - description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict.', - ge=0, - le=6, - title='Safety Tolerance', - ) - webhook_url: Optional[WebhookUrl] = Field( - None, description='URL to receive webhook notifications', title='Webhook Url' - ) - webhook_secret: Optional[str] = Field( - None, - description='Optional secret for webhook signature verification', - title='Webhook Secret', - ) - - -class Controls(BaseModel): - artistic_level: Optional[int] = Field( - None, - description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. 
Dynamic and eccentric levels introduce movement and creativity.', - ge=0, - le=5, - ) - colors: Optional[List[RGBColor]] = Field( - None, description='An array of preferable colors' - ) - background_color: Optional[RGBColor] = Field( - None, description='Use given color as a desired background color' - ) - no_text: Optional[bool] = Field(None, description='Do not embed text layouts') - - -class RecraftImageGenerationRequest(BaseModel): - prompt: str = Field( - ..., description='The text prompt describing the image to generate' - ) - model: str = Field( - ..., description='The model to use for generation (e.g., "recraftv3")' - ) - style: Optional[str] = Field( - None, - description='The style to apply to the generated image (e.g., "digital_illustration")', - ) - style_id: Optional[str] = Field( - None, - description='The style ID to apply to the generated image (e.g., "123e4567-e89b-12d3-a456-426614174000"). If style_id is provided, style should not be provided.', - ) - size: str = Field( - ..., description='The size of the generated image (e.g., "1024x1024")' - ) - controls: Optional[Controls] = Field( - None, description='The controls for the generated image' - ) - n: int = Field(..., description='The number of images to generate', ge=1, le=4) - - -class LumaKeyframes(BaseModel): - frame0: Optional[LumaKeyframe] = None - frame1: Optional[LumaKeyframe] = None +class KlingVideoExtendResponse(BaseModel): + code: Optional[int] = Field(None, description='Error code') + data: Optional[Data6] = None + message: Optional[str] = Field(None, description='Error message') + request_id: Optional[str] = Field(None, description='Request ID') class LumaGenerationRequest(BaseModel): - generation_type: Optional[GenerationType] = 'video' - prompt: str = Field(..., description='The prompt of the generation') aspect_ratio: LumaAspectRatio - loop: Optional[bool] = Field(None, description='Whether to loop the video') - keyframes: Optional[LumaKeyframes] = None callback_url: Optional[AnyUrl] = Field( None, description='The callback URL of the generation, a POST request with Generation object will be sent to the callback URL when the generation is dreaming, completed, or failed', ) - model: LumaVideoModel - resolution: LumaVideoModelOutputResolution duration: LumaVideoModelOutputDuration + generation_type: Optional[GenerationType1] = 'video' + keyframes: Optional[LumaKeyframes] = None + loop: Optional[bool] = Field(None, description='Whether to loop the video') + model: LumaVideoModel + prompt: str = Field(..., description='The prompt of the generation') + resolution: LumaVideoModelOutputResolution + + +class CharacterRef(BaseModel): + identity0: Optional[LumaImageIdentity] = None + + +class LumaImageGenerationRequest(BaseModel): + aspect_ratio: Optional[LumaAspectRatio] = '16:9' + callback_url: Optional[AnyUrl] = Field( + None, description='The callback URL for the generation' + ) + character_ref: Optional[CharacterRef] = None + generation_type: Optional[GenerationType2] = 'image' + image_ref: Optional[List[LumaImageRef]] = None + model: Optional[LumaImageModel] = 'photon-1' + modify_image_ref: Optional[LumaModifyImageRef] = None + prompt: Optional[str] = Field(None, description='The prompt of the generation') + style_ref: Optional[List[LumaImageRef]] = None + + +class LumaUpscaleVideoGenerationRequest(BaseModel): + callback_url: Optional[AnyUrl] = Field( + None, description='The callback URL for the upscale' + ) + generation_type: Optional[GenerationType3] = 'upscale_video' + resolution: 
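Optional[LumaVideoModelOutputResolution] = None

`LumaGenerationRequest` and `LumaUpscaleVideoGenerationRequest` pair naturally: generate first, then request an upscale in a separate call. A hedged sketch; the enum literals ('ray-2', '16:9', '1080p', '5s', '4k') and the import path are assumptions, only the field layout comes from this diff:

```python
# Hedged sketch: a Luma generation request followed by an upscale request.
from comfy_api_nodes.apis import (
    LumaGenerationRequest,
    LumaUpscaleVideoGenerationRequest,
)

gen = LumaGenerationRequest(
    prompt='a paper crane unfolding in the wind',
    model='ray-2',        # LumaVideoModel value, assumed
    aspect_ratio='16:9',  # LumaAspectRatio value, assumed
    resolution='1080p',   # LumaVideoModelOutputResolution value, assumed
    duration='5s',        # LumaVideoModelOutputDuration value, assumed
    loop=False,
)
upscale = LumaUpscaleVideoGenerationRequest(resolution='4k')  # literal assumed
print(gen.model_dump_json(exclude_none=True))
print(upscale.model_dump_json(exclude_none=True))
```
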
+
+
+class OutputContent(RootModel[Union[OutputTextContent, OutputAudioContent]]):
+    root: Union[OutputTextContent, OutputAudioContent]
+
+
+class OutputMessage(BaseModel):
+    content: List[OutputContent] = Field(..., description='The content of the message')
+    role: Role4 = Field(..., description='The role of the message')
+    type: Type14 = Field(..., description='The type of output item')
+
+
+class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
+    duration: Optional[PikaDurationEnum] = 5
+    image: Optional[StrictBytes] = Field(None, title='Image')
+    negativePrompt: Optional[str] = Field(None, title='Negativeprompt')
+    promptText: Optional[str] = Field(None, title='Prompttext')
+    resolution: Optional[PikaResolutionEnum] = '1080p'
+    seed: Optional[int] = Field(None, title='Seed')
+
+
+class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
+    duration: Optional[int] = Field(None, ge=5, le=10, title='Duration')
+    keyFrames: Optional[List[StrictBytes]] = Field(
+        None, description='Array of keyframe images', title='Keyframes'
+    )
+    negativePrompt: Optional[str] = Field(None, title='Negativeprompt')
+    promptText: str = Field(..., title='Prompttext')
+    resolution: Optional[PikaResolutionEnum] = '1080p'
+    seed: Optional[int] = Field(None, title='Seed')
+
+
+class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
+    aspectRatio: Optional[float] = Field(
+        1.7777777777777777,
+        description='Aspect ratio (width / height)',
+        ge=0.4,
+        le=2.5,
+        title='Aspectratio',
+    )
+    duration: Optional[PikaDurationEnum] = 5
+    negativePrompt: Optional[str] = Field(None, title='Negativeprompt')
+    promptText: str = Field(..., title='Prompttext')
+    resolution: Optional[PikaResolutionEnum] = '1080p'
+    seed: Optional[int] = Field(None, title='Seed')
+
+
+class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
+    image: Optional[StrictBytes] = Field(None, title='Image')
+    negativePrompt: Optional[str] = Field(None, title='Negativeprompt')
+    pikaffect: Optional[Pikaffect] = None
+    promptText: Optional[str] = Field(None, title='Prompttext')
+    seed: Optional[int] = Field(None, title='Seed')
+
+
+class PikaHTTPValidationError(BaseModel):
+    detail: Optional[List[PikaValidationError]] = Field(None, title='Detail')
+
+
+class Reasoning(BaseModel):
+    effort: Optional[ReasoningEffort] = 'medium'
+    generate_summary: Optional[GenerateSummary] = Field(
+        None,
+        description="**Deprecated:** use `summary` instead.\n\nA summary of the reasoning performed by the model. This can be\nuseful for debugging and understanding the model's reasoning process.\nOne of `auto`, `concise`, or `detailed`.\n",
+    )
+    summary: Optional[Summary] = Field(
+        None,
+        description="A summary of the reasoning performed by the model. 
This can be\nuseful for debugging and understanding the model's reasoning process.\nOne of `auto`, `concise`, or `detailed`.\n", + ) + + +class ResponseError(BaseModel): + code: ResponseErrorCode + message: str = Field(..., description='A human-readable description of the error.') + + +class Rodin3DDownloadResponse(BaseModel): + list: Optional[RodinResourceItem] = None + + +class Rodin3DGenerateRequest(BaseModel): + images: str = Field(..., description='The reference images to generate 3D Assets.') + material: Optional[RodinMaterialType] = None + mesh_mode: Optional[RodinMeshModeType] = None + quality: Optional[RodinQualityType] = None + seed: Optional[int] = Field(None, description='Seed.') + tier: Optional[RodinTierType] = None + + +class Rodin3DGenerateResponse(BaseModel): + jobs: Optional[RodinGenerateJobsData] = None + message: Optional[str] = Field(None, description='message') + prompt: Optional[str] = Field(None, description='prompt') + submit_time: Optional[str] = Field(None, description='Time') + uuid: Optional[str] = Field(None, description='Task UUID') + + +class RunwayImageToVideoRequest(BaseModel): + duration: RunwayDurationEnum + model: RunwayModelEnum + promptImage: RunwayPromptImageObject + promptText: Optional[str] = Field( + None, description='Text prompt for the generation', max_length=1000 + ) + ratio: RunwayAspectRatioEnum + seed: int = Field( + ..., description='Random seed for generation', ge=0, le=4294967295 + ) + + +class TextResponseFormatConfiguration( + RootModel[ + Union[ + ResponseFormatText, TextResponseFormatJsonSchema, ResponseFormatJsonObject + ] + ] +): + root: Union[ + ResponseFormatText, TextResponseFormatJsonSchema, ResponseFormatJsonObject + ] = Field( + ..., + description='An object specifying the format that the model must output.\n\nConfiguring `{ "type": "json_schema" }` enables Structured Outputs, \nwhich ensures the model will match your supplied JSON schema. Learn more in the \n[Structured Outputs guide](/docs/guides/structured-outputs).\n\nThe default format is `{ "type": "text" }` with no additional options.\n\n**Not recommended for gpt-4o and newer models:**\n\nSetting to `{ "type": "json_object" }` enables the older JSON mode, which\nensures the message the model generates is valid JSON. Using `json_schema`\nis preferred for models that support it.\n', + ) + + +class Tool( + RootModel[ + Union[ + FileSearchTool, FunctionTool, WebSearchPreviewTool, ComputerUsePreviewTool + ] + ] +): + root: Union[ + FileSearchTool, FunctionTool, WebSearchPreviewTool, ComputerUsePreviewTool + ] = Field(..., discriminator='type') + + +class EasyInputMessage(BaseModel): + content: Union[str, InputMessageContentList] = Field( + ..., + description='Text, image, or audio input to the model, used to generate a response.\nCan also contain previous assistant responses.\n', + ) + role: Role = Field( + ..., + description='The role of the message input. One of `user`, `assistant`, `system`, or\n`developer`.\n', + ) + type: Optional[Type2] = Field( + None, description='The type of the message input. 
Always `message`.\n' + ) + + +class GeminiContent(BaseModel): + parts: List[GeminiPart] + role: Role1 = Field(..., examples=['user']) + + +class GeminiGenerateContentRequest(BaseModel): + contents: List[GeminiContent] + generationConfig: Optional[GeminiGenerationConfig] = None + safetySettings: Optional[List[GeminiSafetySetting]] = None + systemInstruction: Optional[GeminiSystemInstructionContent] = None + tools: Optional[List[GeminiTool]] = None + videoMetadata: Optional[GeminiVideoMetadata] = None + + +class ImagenGenerateImageRequest(BaseModel): + instances: List[ImagenImageGenerationInstance] + parameters: ImagenImageGenerationParameters + + +class InputMessage(BaseModel): + content: Optional[InputMessageContentList] = None + role: Optional[Role3] = None + status: Optional[Status2] = None + type: Optional[Type9] = None + + +class Item( + RootModel[ + Union[ + InputMessage, + OutputMessage, + FileSearchToolCall, + ComputerToolCall, + WebSearchToolCall, + FunctionToolCall, + ReasoningItem, + ] + ] +): + root: Union[ + InputMessage, + OutputMessage, + FileSearchToolCall, + ComputerToolCall, + WebSearchToolCall, + FunctionToolCall, + ReasoningItem, + ] = Field(..., description='Content item used to generate a response.\n') class LumaGeneration(BaseModel): - id: Optional[UUID] = Field(None, description='The ID of the generation') - generation_type: Optional[LumaGenerationType] = None - state: Optional[LumaState] = None - failure_reason: Optional[str] = Field( - None, description='The reason for the state of the generation' - ) + assets: Optional[LumaAssets] = None created_at: Optional[datetime] = Field( None, description='The date and time when the generation was created' ) - assets: Optional[LumaAssets] = None + failure_reason: Optional[str] = Field( + None, description='The reason for the state of the generation' + ) + generation_type: Optional[LumaGenerationType] = None + id: Optional[UUID] = Field(None, description='The ID of the generation') model: Optional[str] = Field(None, description='The model used for the generation') request: Optional[ Union[ @@ -3593,237 +3325,129 @@ class LumaGeneration(BaseModel): LumaAudioGenerationRequest, ] ] = Field(None, description='The request of the generation') + state: Optional[LumaState] = None -class RunwayImageToVideoRequest(BaseModel): - promptImage: RunwayPromptImageObject - seed: int = Field( - ..., description='Random seed for generation', ge=0, le=4294967295 +class OutputItem( + RootModel[ + Union[ + OutputMessage, + FileSearchToolCall, + FunctionToolCall, + WebSearchToolCall, + ComputerToolCall, + ReasoningItem, + ] + ] +): + root: Union[ + OutputMessage, + FileSearchToolCall, + FunctionToolCall, + WebSearchToolCall, + ComputerToolCall, + ReasoningItem, + ] + + +class Text(BaseModel): + format: Optional[TextResponseFormatConfiguration] = None + + +class ResponseProperties(BaseModel): + instructions: Optional[str] = Field( + None, + description="Inserts a system (or developer) message as the first item in the model's context.\n\nWhen using along with `previous_response_id`, the instructions from a previous\nresponse will not be carried over to the next response. 
This makes it simple\nto swap out system (or developer) messages in new responses.\n", ) - model: RunwayModelEnum = Field(..., description='Model to use for generation') - promptText: Optional[str] = Field( - None, description='Text prompt for the generation', max_length=1000 + max_output_tokens: Optional[int] = Field( + None, + description='An upper bound for the number of tokens that can be generated for a response, including visible output tokens and [reasoning tokens](/docs/guides/reasoning).\n', ) - duration: RunwayDurationEnum = Field( - ..., description='The number of seconds of duration for the output video.' + model: Optional[OpenAIModels] = None + previous_response_id: Optional[str] = Field( + None, + description='The unique ID of the previous response to the model. Use this to\ncreate multi-turn conversations. Learn more about \n[conversation state](/docs/guides/conversation-state).\n', ) - ratio: RunwayAspectRatioEnum = Field( + reasoning: Optional[Reasoning] = None + text: Optional[Text] = None + tool_choice: Optional[ + Union[ToolChoiceOptions, ToolChoiceTypes, ToolChoiceFunction] + ] = Field( + None, + description='How the model should select which tool (or tools) to use when generating\na response. See the `tools` parameter to see how to specify which tools\nthe model can call.\n', + ) + tools: Optional[List[Tool]] = None + truncation: Optional[Truncation1] = Field( + 'disabled', + description="The truncation strategy to use for the model response.\n- `auto`: If the context of this response and previous ones exceeds\n the model's context window size, the model will truncate the \n response to fit the context window by dropping input items in the\n middle of the conversation. \n- `disabled` (default): If a model response will exceed the context window \n size for a model, the request will fail with a 400 error.\n", + ) + + +class GeminiCandidate(BaseModel): + citationMetadata: Optional[GeminiCitationMetadata] = None + content: Optional[GeminiContent] = None + finishReason: Optional[str] = None + safetyRatings: Optional[List[GeminiSafetyRating]] = None + + +class GeminiGenerateContentResponse(BaseModel): + candidates: Optional[List[GeminiCandidate]] = None + promptFeedback: Optional[GeminiPromptFeedback] = None + + +class InputItem(RootModel[Union[EasyInputMessage, Item]]): + root: Union[EasyInputMessage, Item] + + +class OpenAICreateResponse(CreateModelResponseProperties, ResponseProperties): + include: Optional[List[Includable]] = Field( + None, + description='Specify additional output data to include in the model response. Currently\nsupported values are:\n- `file_search_call.results`: Include the search results of\n the file search tool call.\n- `message.input_image.image_url`: Include image urls from the input message.\n- `computer_call_output.output.image_url`: Include image urls from the computer call output.\n', + ) + input: Union[str, List[InputItem]] = Field( ..., - description='The resolution (aspect ratio) of the output video. Allowable values depend on the selected model. 
1280:768 and 768:1280 are only supported for gen3a_turbo.', + description='Text, image, or file inputs to the model, used to generate a response.\n\nLearn more:\n- [Text inputs and outputs](/docs/guides/text)\n- [Image inputs](/docs/guides/images)\n- [File inputs](/docs/guides/pdf-files)\n- [Conversation state](/docs/guides/conversation-state)\n- [Function calling](/docs/guides/function-calling)\n', ) - - -class RunwayTaskStatusResponse(BaseModel): - id: Optional[str] = Field(None, description='Task ID') - status: Optional[RunwayTaskStatusEnum] = Field(None, description='Task status') - createdAt: Optional[datetime] = Field(None, description='Task creation timestamp') - output: Optional[List[str]] = Field(None, description='Array of output video URLs') - - -class PikaHTTPValidationError(BaseModel): - detail: Optional[List[PikaValidationError]] = Field(None, title='Detail') - - -class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel): - promptText: str = Field(..., title='Prompttext') - negativePrompt: Optional[str] = Field(None, title='Negativeprompt') - seed: Optional[int] = Field(None, title='Seed') - resolution: Optional[PikaResolutionEnum] = Field('1080p', title='Resolution') - duration: Optional[PikaDurationEnum] = Field(5, title='Duration') - aspectRatio: Optional[float] = Field( - 1.7777777777777777, - description='Aspect ratio (width / height)', - ge=0.4, - le=2.5, - title='Aspectratio', + parallel_tool_calls: Optional[bool] = Field( + True, description='Whether to allow the model to run tool calls in parallel.\n' ) - - -class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel): - image: Optional[StrictBytes] = Field(None, title='Image') - promptText: Optional[str] = Field(None, title='Prompttext') - negativePrompt: Optional[str] = Field(None, title='Negativeprompt') - seed: Optional[int] = Field(None, title='Seed') - resolution: Optional[PikaResolutionEnum] = Field('1080p', title='Resolution') - duration: Optional[PikaDurationEnum] = Field(5, title='Duration') - - -class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel): - keyFrames: Optional[List[StrictBytes]] = Field( - None, description='Array of keyframe images', title='Keyframes' + store: Optional[bool] = Field( + True, + description='Whether to store the generated model response for later retrieval via\nAPI.\n', ) - promptText: str = Field(..., title='Prompttext') - negativePrompt: Optional[str] = Field(None, title='Negativeprompt') - seed: Optional[int] = Field(None, title='Seed') - resolution: Optional[PikaResolutionEnum] = Field('1080p', title='Resolution') - duration: Optional[int] = Field(None, ge=5, le=10, title='Duration') + stream: Optional[bool] = Field( + False, + description='If set to true, the model response data will be streamed to the client\nas it is generated using [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format).\nSee the [Streaming section below](/docs/api-reference/responses-streaming)\nfor more information.\n', + ) + usage: Optional[ResponseUsage] = None -class PikaVideoResponse(BaseModel): - id: str = Field(..., title='Id') - status: PikaStatusEnum = Field( - ..., description='The status of the video', title='Status' - ) - url: Optional[str] = Field(None, title='Url') - progress: Optional[int] = Field(None, title='Progress') - - -class Node(BaseModel): - id: Optional[str] = Field(None, description='The unique identifier of the node.') - name: Optional[str] = Field(None, description='The display name of the 
node.') - category: Optional[str] = Field(None, description='The category of the node.') - description: Optional[str] = None - author: Optional[str] = None - license: Optional[str] = Field( - None, description="The path to the LICENSE file in the node's repository." - ) - icon: Optional[str] = Field(None, description="URL to the node's icon.") - repository: Optional[str] = Field(None, description="URL to the node's repository.") - tags: Optional[List[str]] = None - latest_version: Optional[NodeVersion] = Field( - None, description='The latest version of the node.' - ) - rating: Optional[float] = Field(None, description='The average rating of the node.') - downloads: Optional[int] = Field( - None, description='The number of downloads of the node.' - ) - publisher: Optional[Publisher] = Field( - None, description='The publisher of the node.' - ) - status: Optional[NodeStatus] = Field(None, description='The status of the node.') - status_detail: Optional[str] = Field( - None, description='The status detail of the node.' - ) - translations: Optional[Dict[str, Dict[str, Any]]] = None - - -class KlingVideoEffectsRequest(BaseModel): - effect_scene: Union[KlingDualCharacterEffectsScene, KlingSingleImageEffectsScene] - input: KlingVideoEffectsInput - callback_url: Optional[AnyUrl] = Field( +class OpenAIResponse(ModelResponseProperties, ResponseProperties): + created_at: Optional[float] = Field( None, - description='The callback notification address for the result of this task.', + description='Unix timestamp (in seconds) of when this Response was created.', ) - external_task_id: Optional[str] = Field( + error: Optional[ResponseError] = None + id: Optional[str] = Field(None, description='Unique identifier for this Response.') + incomplete_details: Optional[IncompleteDetails] = Field( + None, description='Details about why the response is incomplete.\n' + ) + object: Optional[Object] = Field( + None, description='The object type of this resource - always set to `response`.' + ) + output: Optional[List[OutputItem]] = Field( None, - description='Customized Task ID. 
Must be unique within a single user account.', + description="An array of content items generated by the model.\n\n- The length and order of items in the `output` array is dependent\n on the model's response.\n- Rather than accessing the first item in the `output` array and \n assuming it's an `assistant` message with the content generated by\n the model, you might consider using the `output_text` property where\n supported in SDKs.\n", ) - - -class StripeCharge(BaseModel): - id: Optional[str] = None - object: Optional[Object2] = None - amount: Optional[int] = None - amount_captured: Optional[int] = None - amount_refunded: Optional[int] = None - application: Optional[str] = None - application_fee: Optional[str] = None - application_fee_amount: Optional[int] = None - balance_transaction: Optional[str] = None - billing_details: Optional[StripeBillingDetails] = None - calculated_statement_descriptor: Optional[str] = None - captured: Optional[bool] = None - created: Optional[int] = None - currency: Optional[str] = None - customer: Optional[str] = None - description: Optional[str] = None - destination: Optional[Any] = None - dispute: Optional[Any] = None - disputed: Optional[bool] = None - failure_balance_transaction: Optional[Any] = None - failure_code: Optional[Any] = None - failure_message: Optional[Any] = None - fraud_details: Optional[Dict[str, Any]] = None - invoice: Optional[Any] = None - livemode: Optional[bool] = None - metadata: Optional[Dict[str, Any]] = None - on_behalf_of: Optional[Any] = None - order: Optional[Any] = None - outcome: Optional[StripeOutcome] = None - paid: Optional[bool] = None - payment_intent: Optional[str] = None - payment_method: Optional[str] = None - payment_method_details: Optional[StripePaymentMethodDetails] = None - radar_options: Optional[Dict[str, Any]] = None - receipt_email: Optional[str] = None - receipt_number: Optional[str] = None - receipt_url: Optional[str] = None - refunded: Optional[bool] = None - refunds: Optional[StripeRefundList] = None - review: Optional[Any] = None - shipping: Optional[StripeShipping] = None - source: Optional[Any] = None - source_transfer: Optional[Any] = None - statement_descriptor: Optional[Any] = None - statement_descriptor_suffix: Optional[Any] = None - status: Optional[str] = None - transfer_data: Optional[Any] = None - transfer_group: Optional[Any] = None - - -class StripeChargeList(BaseModel): - object: Optional[str] = None - data: Optional[List[StripeCharge]] = None - has_more: Optional[bool] = None - total_count: Optional[int] = None - url: Optional[str] = None - - -class StripePaymentIntent(BaseModel): - id: Optional[str] = None - object: Optional[Object1] = None - amount: Optional[int] = None - amount_capturable: Optional[int] = None - amount_details: Optional[StripeAmountDetails] = None - amount_received: Optional[int] = None - application: Optional[str] = None - application_fee_amount: Optional[int] = None - automatic_payment_methods: Optional[Any] = None - canceled_at: Optional[int] = None - cancellation_reason: Optional[str] = None - capture_method: Optional[str] = None - charges: Optional[StripeChargeList] = None - client_secret: Optional[str] = None - confirmation_method: Optional[str] = None - created: Optional[int] = None - currency: Optional[str] = None - customer: Optional[str] = None - description: Optional[str] = None - invoice: Optional[str] = None - last_payment_error: Optional[Any] = None - latest_charge: Optional[str] = None - livemode: Optional[bool] = None - metadata: Optional[Dict[str, Any]] = 
None
-    next_action: Optional[Any] = None
-    on_behalf_of: Optional[Any] = None
-    payment_method: Optional[str] = None
-    payment_method_configuration_details: Optional[Any] = None
-    payment_method_options: Optional[StripePaymentMethodOptions] = None
-    payment_method_types: Optional[List[str]] = None
-    processing: Optional[Any] = None
-    receipt_email: Optional[str] = None
-    review: Optional[Any] = None
-    setup_future_usage: Optional[Any] = None
-    shipping: Optional[StripeShipping] = None
-    source: Optional[Any] = None
-    statement_descriptor: Optional[Any] = None
-    statement_descriptor_suffix: Optional[Any] = None
-    status: Optional[str] = None
-    transfer_data: Optional[Any] = None
-    transfer_group: Optional[Any] = None
-
-
-class Data8(BaseModel):
-    object: Optional[StripePaymentIntent] = None
-
-
-class StripeEvent(BaseModel):
-    id: str
-    object: Object
-    api_version: Optional[str] = None
-    created: Optional[int] = None
-    data: Data8
-    livemode: Optional[bool] = None
-    pending_webhooks: Optional[int] = None
-    request: Optional[StripeRequestInfo] = None
-    type: Type
+    output_text: Optional[str] = Field(
+        None,
+        description='SDK-only convenience property that contains the aggregated text output \nfrom all `output_text` items in the `output` array, if any are present. \nSupported in the Python and JavaScript SDKs.\n',
+    )
+    parallel_tool_calls: Optional[bool] = Field(
+        True, description='Whether to allow the model to run tool calls in parallel.\n'
+    )
+    status: Optional[Status6] = Field(
+        None,
+        description='The status of the response generation. One of `completed`, `failed`, `in_progress`, or `incomplete`.',
+    )
+    usage: Optional[ResponseUsage] = None
diff --git a/comfy_api_nodes/apis/bfl_api.py b/comfy_api_nodes/apis/bfl_api.py
index c189038fb..0e90aef7c 100644
--- a/comfy_api_nodes/apis/bfl_api.py
+++ b/comfy_api_nodes/apis/bfl_api.py
@@ -108,6 +108,24 @@ class BFLFluxProGenerateRequest(BaseModel):
 
     # )
+
+class BFLFluxKontextProGenerateRequest(BaseModel):
+    prompt: str = Field(..., description='The text prompt for what you want to edit.')
+    input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
+    seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
+    guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process')
+    steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process')
+    safety_tolerance: Optional[conint(ge=0, le=2)] = Field(
+        2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict, 2 being least strict. Defaults to 2.'
+    )
+    output_format: Optional[BFLOutputFormat] = Field(
+        BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
+    )
+    aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
+    prompt_upsampling: Optional[bool] = Field(
+        None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
+ ) + + class BFLFluxProUltraGenerateRequest(BaseModel): prompt: str = Field(..., description='The text prompt for image generation.') prompt_upsampling: Optional[bool] = Field( diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py index cff52714f..2a4bac88b 100644 --- a/comfy_api_nodes/apis/client.py +++ b/comfy_api_nodes/apis/client.py @@ -94,15 +94,19 @@ from __future__ import annotations import logging import time import io -from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable +import socket +from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple from enum import Enum import json import requests -from urllib.parse import urljoin +from urllib.parse import urljoin, urlparse from pydantic import BaseModel, Field +import uuid # For generating unique operation IDs +from server import PromptServer from comfy.cli_args import args from comfy import utils +from . import request_logger T = TypeVar("T", bound=BaseModel) R = TypeVar("R", bound=BaseModel) @@ -111,6 +115,21 @@ P = TypeVar("P", bound=BaseModel) # For poll response PROGRESS_BAR_MAX = 100 +class NetworkError(Exception): + """Base exception for network-related errors with diagnostic information.""" + pass + + +class LocalNetworkError(NetworkError): + """Exception raised when local network connectivity issues are detected.""" + pass + + +class ApiServerError(NetworkError): + """Exception raised when the API server is unreachable but internet is working.""" + pass + + class EmptyRequest(BaseModel): """Base class for empty request bodies. For GET requests, fields will be sent as query parameters.""" @@ -120,7 +139,7 @@ class EmptyRequest(BaseModel): class UploadRequest(BaseModel): file_name: str = Field(..., description="Filename to upload") - content_type: str | None = Field( + content_type: Optional[str] = Field( None, description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.", ) @@ -141,7 +160,7 @@ class HttpMethod(str, Enum): class ApiClient: """ - Client for making HTTP requests to an API with authentication and error handling. + Client for making HTTP requests to an API with authentication, error handling, and retry logic. """ def __init__( @@ -151,12 +170,26 @@ class ApiClient: comfy_api_key: Optional[str] = None, timeout: float = 3600.0, verify_ssl: bool = True, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff_factor: float = 2.0, + retry_status_codes: Optional[Tuple[int, ...]] = None, ): self.base_url = base_url self.auth_token = auth_token self.comfy_api_key = comfy_api_key self.timeout = timeout self.verify_ssl = verify_ssl + self.max_retries = max_retries + self.retry_delay = retry_delay + self.retry_backoff_factor = retry_backoff_factor + # Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests), + # 500, 502, 503, 504 (Server Errors) + self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504) + + def _generate_operation_id(self, path: str) -> str: + """Generates a unique operation ID for logging.""" + return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}" def _create_json_payload_args( self, @@ -211,6 +244,56 @@ class ApiClient: return headers + def _check_connectivity(self, target_url: str) -> Dict[str, bool]: + """ + Check connectivity to determine if network issues are local or server-related. 
+ + Args: + target_url: URL to check connectivity to + + Returns: + Dictionary with connectivity status details + """ + results = { + "internet_accessible": False, + "api_accessible": False, + "is_local_issue": False, + "is_api_issue": False + } + + # First check basic internet connectivity using a reliable external site + try: + # Use a reliable external domain for checking basic connectivity + check_response = requests.get("https://www.google.com", + timeout=5.0, + verify=self.verify_ssl) + if check_response.status_code < 500: + results["internet_accessible"] = True + except (requests.RequestException, socket.error): + results["internet_accessible"] = False + results["is_local_issue"] = True + return results + + # Now check API server connectivity + try: + # Extract domain from the target URL to do a simpler health check + parsed_url = urlparse(target_url) + api_base = f"{parsed_url.scheme}://{parsed_url.netloc}" + + # Try to reach the API domain + api_response = requests.get(f"{api_base}/health", timeout=5.0, verify=self.verify_ssl) + if api_response.status_code < 500: + results["api_accessible"] = True + else: + results["api_accessible"] = False + results["is_api_issue"] = True + except requests.RequestException: + results["api_accessible"] = False + # If we can reach the internet but not the API, it's an API issue + results["is_api_issue"] = True + + return results + def request( self, method: str, @@ -221,9 +304,10 @@ class ApiClient: headers: Optional[Dict[str, str]] = None, content_type: str = "application/json", multipart_parser: Callable = None, + retry_count: int = 0, # Used internally for tracking retries ) -> Dict[str, Any]: """ - Make an HTTP request to the API + Make an HTTP request to the API with automatic retries for transient errors. Args: method: HTTP method (GET, POST, etc.) @@ -233,14 +317,19 @@ class ApiClient: files: Files to upload headers: Additional headers content_type: Content type of the request. Defaults to application/json. + retry_count: Internal parameter for tracking retries, do not set manually Returns: Parsed JSON response Raises: - requests.RequestException: If the request fails + LocalNetworkError: If local network connectivity issues are detected + ApiServerError: If the API server is unreachable but internet is working + Exception: For other request failures """ - url = urljoin(self.base_url, path) + # Use urljoin but ensure path is relative to avoid absolute path behavior + relative_path = path.lstrip('/') + url = urljoin(self.base_url, relative_path) self.check_auth(self.auth_token, self.comfy_api_key) # Combine default headers with any provided headers request_headers = self.get_headers() @@ -265,6 +354,16 @@ class ApiClient: else: payload_args = self._create_json_payload_args(data, request_headers) + operation_id = self._generate_operation_id(path) + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=request_headers, + request_params=params, + request_data=data if content_type == "application/json" else "[form-data or other]" + ) + try: response = requests.request( method=method, @@ -275,50 +374,228 @@ class ApiClient: **payload_args, ) + # Check if we should retry based on status code + if (response.status_code in self.retry_status_codes and + retry_count < self.max_retries): + + # Calculate delay with exponential backoff + delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) + + logging.warning( + f"Request failed with status {response.status_code}. 
" + f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" + ) + + time.sleep(delay) + return self.request( + method=method, + path=path, + params=params, + data=data, + files=files, + headers=headers, + content_type=content_type, + multipart_parser=multipart_parser, + retry_count=retry_count + 1, + ) + # Raise exception for error status codes response.raise_for_status() - except requests.ConnectionError: - raise Exception( - f"Unable to connect to the API server at {self.base_url}. Please check your internet connection or verify the service is available." + + # Log successful response + response_content_to_log = response.content + try: + # Attempt to parse JSON for prettier logging, fallback to raw content + response_content_to_log = response.json() + except json.JSONDecodeError: + pass # Keep as bytes/str if not JSON + + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, # Pass request details again for context in log + request_url=url, + response_status_code=response.status_code, + response_headers=dict(response.headers), + response_content=response_content_to_log ) - except requests.Timeout: - raise Exception( - f"Request timed out after {self.timeout} seconds. The server might be experiencing high load or the operation is taking longer than expected." + except requests.ConnectionError as e: + error_message = f"ConnectionError: {str(e)}" + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + error_message=error_message ) + # Only perform connectivity check if we've exhausted all retries + if retry_count >= self.max_retries: + # Check connectivity to determine if it's a local or API issue + connectivity = self._check_connectivity(self.base_url) + + if connectivity["is_local_issue"]: + raise LocalNetworkError( + "Unable to connect to the API server due to local network issues. " + "Please check your internet connection and try again." + ) from e + elif connectivity["is_api_issue"]: + raise ApiServerError( + f"The API server at {self.base_url} is currently unreachable. " + f"The service may be experiencing issues. Please try again later." + ) from e + + # If we haven't exhausted retries yet, retry the request + if retry_count < self.max_retries: + delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) + logging.warning( + f"Connection error: {str(e)}. " + f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" + ) + time.sleep(delay) + return self.request( + method=method, + path=path, + params=params, + data=data, + files=files, + headers=headers, + content_type=content_type, + multipart_parser=multipart_parser, + retry_count=retry_count + 1, + ) + + # If we've exhausted retries and didn't identify the specific issue, + # raise a generic exception + final_error_message = ( + f"Unable to connect to the API server after {self.max_retries} attempts. " + f"Please check your internet connection or try again later." 
+ ) + request_logger.log_request_response( # Log final failure + operation_id=operation_id, + request_method=method, request_url=url, + error_message=final_error_message + ) + raise Exception(final_error_message) from e + + except requests.Timeout as e: + error_message = f"Timeout: {str(e)}" + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, request_url=url, + error_message=error_message + ) + # Retry timeouts if we haven't exhausted retries + if retry_count < self.max_retries: + delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) + logging.warning( + f"Request timed out. " + f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" + ) + time.sleep(delay) + return self.request( + method=method, + path=path, + params=params, + data=data, + files=files, + headers=headers, + content_type=content_type, + multipart_parser=multipart_parser, + retry_count=retry_count + 1, + ) + final_error_message = ( + f"Request timed out after {self.timeout} seconds and {self.max_retries} retry attempts. " + f"The server might be experiencing high load or the operation is taking longer than expected." + ) + request_logger.log_request_response( # Log final failure + operation_id=operation_id, + request_method=method, request_url=url, + error_message=final_error_message + ) + raise Exception(final_error_message) from e except requests.HTTPError as e: status_code = e.response.status_code if hasattr(e, "response") else None - error_message = f"HTTP Error: {str(e)}" + original_error_message = f"HTTP Error: {str(e)}" + error_content_for_log = None + if hasattr(e, "response") and e.response is not None: + error_content_for_log = e.response.content + try: + error_content_for_log = e.response.json() + except json.JSONDecodeError: + pass + + + # Try to extract detailed error message from JSON response for user display + # but log the full error content. 
+ user_display_error_message = original_error_message - # Try to extract detailed error message from JSON response try: - if hasattr(e, "response") and e.response.content: + if hasattr(e, "response") and e.response is not None and e.response.content: error_json = e.response.json() if "error" in error_json and "message" in error_json["error"]: - error_message = f"API Error: {error_json['error']['message']}" + user_display_error_message = f"API Error: {error_json['error']['message']}" if "type" in error_json["error"]: - error_message += f" (Type: {error_json['error']['type']})" + user_display_error_message += f" (Type: {error_json['error']['type']})" + elif isinstance(error_json, dict): # Handle cases where error is just a JSON dict + user_display_error_message = f"API Error: {json.dumps(error_json)}" + else: # Non-dict JSON error + user_display_error_message = f"API Error: {str(error_json)}" + except json.JSONDecodeError: + # If not JSON, use the raw content if it's not too long, or a summary + if hasattr(e, "response") and e.response is not None and e.response.content: + raw_content = e.response.content.decode(errors='ignore') + if len(raw_content) < 200: # Arbitrary limit for display + user_display_error_message = f"API Error (raw): {raw_content}" else: - error_message = f"API Error: {error_json}" - except Exception as json_error: - # If we can't parse the JSON, fall back to the original error message - logging.debug( - f"[DEBUG] Failed to parse error response: {str(json_error)}" + user_display_error_message = f"API Error (raw, status {status_code})" + + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, request_url=url, + response_status_code=status_code, + response_headers=dict(e.response.headers) if hasattr(e, "response") and e.response is not None else None, + response_content=error_content_for_log, + error_message=original_error_message # Log the original exception string as error + ) + + logging.debug(f"[DEBUG] API Error: {user_display_error_message} (Status: {status_code})") + if hasattr(e, "response") and e.response is not None and e.response.content: + logging.debug(f"[DEBUG] Response content: {e.response.content}") + + # Retry if the status code is in our retry list and we haven't exhausted retries + if (status_code in self.retry_status_codes and + retry_count < self.max_retries): + + delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) + logging.warning( + f"HTTP error {status_code}. " + f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" + ) + time.sleep(delay) + return self.request( + method=method, + path=path, + params=params, + data=data, + files=files, + headers=headers, + content_type=content_type, + multipart_parser=multipart_parser, + retry_count=retry_count + 1, ) - logging.debug(f"[DEBUG] API Error: {error_message} (Status: {status_code})") - if hasattr(e, "response") and e.response.content: - logging.debug(f"[DEBUG] Response content: {e.response.content}") + # Specific error messages for common status codes for user display if status_code == 401: - error_message = "Unauthorized: Please login first to use this node." - if status_code == 402: - error_message = "Payment Required: Please add credits to your account to use this node." - if status_code == 409: - error_message = "There is a problem with your account. Please contact support@comfy.org. " - if status_code == 429: - error_message = "Rate Limit Exceeded: Please try again later." 
- raise Exception(error_message) + user_display_error_message = "Unauthorized: Please login first to use this node." + elif status_code == 402: + user_display_error_message = "Payment Required: Please add credits to your account to use this node." + elif status_code == 409: + user_display_error_message = "There is a problem with your account. Please contact support@comfy.org." + elif status_code == 429: + user_display_error_message = "Rate Limit Exceeded: Please try again later." + # else, user_display_error_message remains as parsed from response or original HTTPError string + + raise Exception(user_display_error_message) # Raise with the user-friendly message # Parse and return JSON response if response.content: @@ -336,26 +613,126 @@ class ApiClient: upload_url: str, file: io.BytesIO | str, content_type: str | None = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff_factor: float = 2.0, ): - """Upload a file to the API. Make sure the file has a filename equal to what the url expects. + """Upload a file to the API with retry logic. Args: upload_url: The URL to upload to file: Either a file path string, BytesIO object, or tuple of (file_path, filename) - mime_type: Optional mime type to set for the upload + content_type: Optional mime type to set for the upload + max_retries: Maximum number of retry attempts + retry_delay: Initial delay between retries in seconds + retry_backoff_factor: Multiplier for the delay after each retry """ headers = {} if content_type: headers["Content-Type"] = content_type + # Prepare the file data if isinstance(file, io.BytesIO): file.seek(0) # Ensure we're at the start of the file data = file.read() - return requests.put(upload_url, data=data, headers=headers) elif isinstance(file, str): with open(file, "rb") as f: data = f.read() - return requests.put(upload_url, data=data, headers=headers) + else: + raise ValueError("File must be either a BytesIO object or a file path string") + + # Try the upload with retries + last_exception = None + operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}" # Simplified ID for uploads + + # Log initial attempt (without full file data for brevity) + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + request_headers=headers, + request_data=f"[File data of type {content_type or 'unknown'}, size {len(data)} bytes]" + ) + + for retry_attempt in range(max_retries + 1): + try: + response = requests.put(upload_url, data=data, headers=headers) + response.raise_for_status() + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", request_url=upload_url, # For context + response_status_code=response.status_code, + response_headers=dict(response.headers), + response_content="File uploaded successfully." 
# Or response.text if available + ) + return response + + except (requests.ConnectionError, requests.Timeout, requests.HTTPError) as e: + last_exception = e + error_message_for_log = f"{type(e).__name__}: {str(e)}" + response_content_for_log = None + status_code_for_log = None + headers_for_log = None + + if hasattr(e, 'response') and e.response is not None: + status_code_for_log = e.response.status_code + headers_for_log = dict(e.response.headers) + try: + response_content_for_log = e.response.json() + except json.JSONDecodeError: + response_content_for_log = e.response.content + + + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", request_url=upload_url, + response_status_code=status_code_for_log, + response_headers=headers_for_log, + response_content=response_content_for_log, + error_message=error_message_for_log + ) + + if retry_attempt < max_retries: + delay = retry_delay * (retry_backoff_factor ** retry_attempt) + logging.warning( + f"File upload failed: {str(e)}. " + f"Retrying in {delay:.2f}s ({retry_attempt + 1}/{max_retries})" + ) + time.sleep(delay) + else: + break # Max retries reached + + # If we've exhausted all retries, determine the final error type and raise + final_error_message = f"Failed to upload file after {max_retries + 1} attempts. Error: {str(last_exception)}" + try: + # Check basic internet connectivity + check_response = requests.get("https://www.google.com", timeout=5.0, verify=True) # Assuming verify=True is desired + if check_response.status_code >= 500: # Google itself has an issue (rare) + final_error_message = (f"Failed to upload file. Internet connectivity check to Google failed " + f"(status {check_response.status_code}). Original error: {str(last_exception)}") + # Not raising LocalNetworkError here as Google itself might be down. + # If Google is reachable, the issue is likely with the upload server or a more specific local problem + # not caught by a simple Google ping (e.g., DNS for the specific upload URL, firewall). + # The original last_exception is probably most relevant. + + except (requests.RequestException, socket.error) as conn_check_exc: + # Could not reach Google, likely a local network issue + final_error_message = (f"Failed to upload file due to network connectivity issues " + f"(cannot reach Google: {str(conn_check_exc)}). 
" + f"Original upload error: {str(last_exception)}") + request_logger.log_request_response( # Log final failure reason + operation_id=operation_id, + request_method="PUT", request_url=upload_url, + error_message=final_error_message + ) + raise LocalNetworkError(final_error_message) from last_exception + + request_logger.log_request_response( # Log final failure reason if not LocalNetworkError + operation_id=operation_id, + request_method="PUT", request_url=upload_url, + error_message=final_error_message + ) + raise Exception(final_error_message) from last_exception class ApiEndpoint(Generic[T, R]): @@ -403,6 +780,9 @@ class SynchronousOperation(Generic[T, R]): verify_ssl: bool = True, content_type: str = "application/json", multipart_parser: Callable = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff_factor: float = 2.0, ): self.endpoint = endpoint self.request = request @@ -419,8 +799,12 @@ class SynchronousOperation(Generic[T, R]): self.files = files self.content_type = content_type self.multipart_parser = multipart_parser + self.max_retries = max_retries + self.retry_delay = retry_delay + self.retry_backoff_factor = retry_backoff_factor + def execute(self, client: Optional[ApiClient] = None) -> R: - """Execute the API operation using the provided client or create one""" + """Execute the API operation using the provided client or create one with retry support""" try: # Create client if not provided if client is None: @@ -430,6 +814,9 @@ class SynchronousOperation(Generic[T, R]): comfy_api_key=self.comfy_api_key, timeout=self.timeout, verify_ssl=self.verify_ssl, + max_retries=self.max_retries, + retry_delay=self.retry_delay, + retry_backoff_factor=self.retry_backoff_factor, ) # Convert request model to dict, but use None for EmptyRequest @@ -443,11 +830,6 @@ class SynchronousOperation(Generic[T, R]): if isinstance(value, Enum): request_dict[key] = value.value - if request_dict: - for key, value in request_dict.items(): - if isinstance(value, Enum): - request_dict[key] = value.value - # Debug log for request logging.debug( f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}" @@ -455,7 +837,7 @@ class SynchronousOperation(Generic[T, R]): logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}") logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}") - # Make the request + # Make the request with built-in retry resp = client.request( method=self.endpoint.method.value, path=self.endpoint.path, @@ -476,8 +858,18 @@ class SynchronousOperation(Generic[T, R]): # Parse and return the response return self._parse_response(resp) + except LocalNetworkError as e: + # Propagate specific network error types + logging.error(f"[ERROR] Local network error: {str(e)}") + raise + + except ApiServerError as e: + # Propagate API server errors + logging.error(f"[ERROR] API server error: {str(e)}") + raise + except Exception as e: - logging.error(f"[DEBUG] API Exception: {str(e)}") + logging.error(f"[ERROR] API Exception: {str(e)}") raise Exception(str(e)) def _parse_response(self, resp): @@ -511,12 +903,19 @@ class PollingOperation(Generic[T, R]): failed_statuses: list, status_extractor: Callable[[R], str], progress_extractor: Callable[[R], float] = None, + result_url_extractor: Callable[[R], str] = None, request: Optional[T] = None, api_base: str | None = None, auth_token: Optional[str] = None, comfy_api_key: Optional[str] = None, auth_kwargs: Optional[Dict[str,str]] = None, poll_interval: float = 5.0, + max_poll_attempts: int = 120, # 
Default max polling attempts (10 minutes with 5s interval) + max_retries: int = 3, # Max retries per individual API call + retry_delay: float = 1.0, + retry_backoff_factor: float = 2.0, + estimated_duration: Optional[float] = None, + node_id: Optional[str] = None, ): self.poll_endpoint = poll_endpoint self.request = request @@ -527,12 +926,19 @@ class PollingOperation(Generic[T, R]): self.auth_token = auth_kwargs.get("auth_token", self.auth_token) self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key) self.poll_interval = poll_interval + self.max_poll_attempts = max_poll_attempts + self.max_retries = max_retries + self.retry_delay = retry_delay + self.retry_backoff_factor = retry_backoff_factor + self.estimated_duration = estimated_duration # Polling configuration self.status_extractor = status_extractor or ( lambda x: getattr(x, "status", None) ) self.progress_extractor = progress_extractor + self.result_url_extractor = result_url_extractor + self.node_id = node_id self.completed_statuses = completed_statuses self.failed_statuses = failed_statuses @@ -548,11 +954,46 @@ class PollingOperation(Generic[T, R]): base_url=self.api_base, auth_token=self.auth_token, comfy_api_key=self.comfy_api_key, + max_retries=self.max_retries, + retry_delay=self.retry_delay, + retry_backoff_factor=self.retry_backoff_factor, ) return self._poll_until_complete(client) + except LocalNetworkError as e: + # Provide clear message for local network issues + raise Exception( + f"Polling failed due to local network issues. Please check your internet connection. " + f"Details: {str(e)}" + ) from e + except ApiServerError as e: + # Provide clear message for API server issues + raise Exception( + f"Polling failed due to API server issues. The service may be experiencing problems. " + f"Please try again later. 
Details: {str(e)}"
+            ) from e
         except Exception as e:
             raise Exception(f"Error during polling: {str(e)}")
 
+    def _display_text_on_node(self, text: str):
+        """Sends text to the client which will be displayed on the node in the UI"""
+        if not self.node_id:
+            return
+
+        PromptServer.instance.send_progress_text(text, self.node_id)
+
+    def _display_time_progress_on_node(self, time_completed: float):
+        if not self.node_id:
+            return
+
+        if self.estimated_duration is not None:
+            estimated_time_remaining = max(
+                0, int(self.estimated_duration) - int(time_completed)
+            )
+            message = f"Task in progress: {time_completed:.0f}s (~{estimated_time_remaining:.0f}s remaining)"
+        else:
+            message = f"Task in progress: {time_completed:.0f}s"
+        self._display_text_on_node(message)
+
     def _check_task_status(self, response: R) -> TaskStatus:
         """Check task status using the status extractor function"""
         try:
@@ -569,10 +1010,14 @@ class PollingOperation(Generic[T, R]):
 
     def _poll_until_complete(self, client: ApiClient) -> R:
         """Poll until the task is complete"""
         poll_count = 0
+        consecutive_errors = 0
+        max_consecutive_errors = min(5, self.max_retries * 2)  # Limit consecutive errors
+        status = None  # May be read in the generic error handler before the first successful poll
+
         if self.progress_extractor:
             progress = utils.ProgressBar(PROGRESS_BAR_MAX)
-        while True:
+        while poll_count < self.max_poll_attempts:
             try:
                 poll_count += 1
                 logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
@@ -599,8 +1044,12 @@
                     data=request_dict,
                 )
 
+                # Successfully got a response, reset consecutive error count
+                consecutive_errors = 0
+
                 # Parse response
                 response_obj = self.poll_endpoint.response_model.model_validate(resp)
+
                 # Check if task is complete
                 status = self._check_task_status(response_obj)
                 logging.debug(f"[DEBUG] Task Status: {status}")
@@ -612,7 +1061,15 @@
                     progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
 
                 if status == TaskStatus.COMPLETED:
-                    logging.debug("[DEBUG] Task completed successfully")
+                    message = "Task completed successfully"
+                    if self.result_url_extractor:
+                        result_url = self.result_url_extractor(response_obj)
+                        if result_url:
+                            message = f"Result URL: {result_url}"
+                    else:
+                        message = "Task completed successfully!"
+                    logging.debug(f"[DEBUG] {message}")
+                    self._display_text_on_node(message)
                     self.final_response = response_obj
                     if self.progress_extractor:
                         progress.update(100)
@@ -628,8 +1085,43 @@
                 logging.debug(
                     f"[DEBUG] Waiting {self.poll_interval} seconds before next poll"
                 )
+                for i in range(int(self.poll_interval)):
+                    time_completed = (poll_count * self.poll_interval) + i
+                    self._display_time_progress_on_node(time_completed)
+                    time.sleep(1)
+
+            except (LocalNetworkError, ApiServerError) as e:
+                # For network-related errors, increment error count and potentially abort
+                consecutive_errors += 1
+                if consecutive_errors >= max_consecutive_errors:
+                    raise Exception(
+                        f"Polling aborted after {consecutive_errors} consecutive network errors: {str(e)}"
+                    ) from e
+
+                # Log the error but continue polling
+                logging.warning(
+                    f"Network error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
+                    f"Will retry in {self.poll_interval} seconds."
+ ) time.sleep(self.poll_interval) except Exception as e: + # For other errors, increment count and potentially abort + consecutive_errors += 1 + if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED: + raise Exception( + f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}" + ) from e + logging.error(f"[DEBUG] Polling error: {str(e)}") - raise Exception(f"Error while polling: {str(e)}") + logging.warning( + f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. " + f"Will retry in {self.poll_interval} seconds." + ) + time.sleep(self.poll_interval) + + # If we've exhausted all polling attempts + raise Exception( + f"Polling timed out after {poll_count} attempts ({poll_count * self.poll_interval} seconds). " + f"The operation may still be running on the server but is taking longer than expected." + ) diff --git a/comfy_api_nodes/apis/request_logger.py b/comfy_api_nodes/apis/request_logger.py new file mode 100644 index 000000000..93517ede9 --- /dev/null +++ b/comfy_api_nodes/apis/request_logger.py @@ -0,0 +1,125 @@ +import os +import datetime +import json +import logging +import folder_paths + +# Get the logger instance +logger = logging.getLogger(__name__) + +def get_log_directory(): + """ + Ensures the API log directory exists within ComfyUI's temp directory + and returns its path. + """ + base_temp_dir = folder_paths.get_temp_directory() + log_dir = os.path.join(base_temp_dir, "api_logs") + try: + os.makedirs(log_dir, exist_ok=True) + except Exception as e: + logger.error(f"Error creating API log directory {log_dir}: {e}") + # Fallback to base temp directory if sub-directory creation fails + return base_temp_dir + return log_dir + +def _format_data_for_logging(data): + """Helper to format data (dict, str, bytes) for logging.""" + if isinstance(data, bytes): + try: + return data.decode('utf-8') # Try to decode as text + except UnicodeDecodeError: + return f"[Binary data of length {len(data)} bytes]" + elif isinstance(data, (dict, list)): + try: + return json.dumps(data, indent=2, ensure_ascii=False) + except TypeError: + return str(data) # Fallback for non-serializable objects + return str(data) + +def log_request_response( + operation_id: str, + request_method: str, + request_url: str, + request_headers: dict | None = None, + request_params: dict | None = None, + request_data: any = None, + response_status_code: int | None = None, + response_headers: dict | None = None, + response_content: any = None, + error_message: str | None = None +): + """ + Logs API request and response details to a file in the temp/api_logs directory. 
+ """ + log_dir = get_log_directory() + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") + filename = f"{timestamp}_{operation_id.replace('/', '_').replace(':', '_')}.log" + filepath = os.path.join(log_dir, filename) + + log_content = [] + + log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}") + log_content.append(f"Operation ID: {operation_id}") + log_content.append("-" * 30 + " REQUEST " + "-" * 30) + log_content.append(f"Method: {request_method}") + log_content.append(f"URL: {request_url}") + if request_headers: + log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}") + if request_params: + log_content.append(f"Params:\n{_format_data_for_logging(request_params)}") + if request_data: + log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}") + + log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30) + if response_status_code is not None: + log_content.append(f"Status Code: {response_status_code}") + if response_headers: + log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}") + if response_content: + log_content.append(f"Content:\n{_format_data_for_logging(response_content)}") + if error_message: + log_content.append(f"Error:\n{error_message}") + + try: + with open(filepath, "w", encoding="utf-8") as f: + f.write("\n".join(log_content)) + logger.debug(f"API log saved to: {filepath}") + except Exception as e: + logger.error(f"Error writing API log to {filepath}: {e}") + +if __name__ == '__main__': + # Example usage (for testing the logger directly) + logger.setLevel(logging.DEBUG) + # Mock folder_paths for direct execution if not running within ComfyUI full context + if not hasattr(folder_paths, 'get_temp_directory'): + class MockFolderPaths: + def get_temp_directory(self): + # Create a local temp dir for testing if needed + p = os.path.join(os.path.dirname(__file__), 'temp_test_logs') + os.makedirs(p, exist_ok=True) + return p + folder_paths = MockFolderPaths() + + log_request_response( + operation_id="test_operation_get", + request_method="GET", + request_url="https://api.example.com/test", + request_headers={"Authorization": "Bearer testtoken"}, + request_params={"param1": "value1"}, + response_status_code=200, + response_content={"message": "Success!"} + ) + log_request_response( + operation_id="test_operation_post_error", + request_method="POST", + request_url="https://api.example.com/submit", + request_data={"key": "value", "nested": {"num": 123}}, + error_message="Connection timed out" + ) + log_request_response( + operation_id="test_binary_response", + request_method="GET", + request_url="https://api.example.com/image.png", + response_status_code=200, + response_content=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR...' 
# Sample binary data
+    )
diff --git a/comfy_api_nodes/apis/rodin_api.py b/comfy_api_nodes/apis/rodin_api.py
new file mode 100644
index 000000000..b0cf171fa
--- /dev/null
+++ b/comfy_api_nodes/apis/rodin_api.py
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+from enum import Enum
+from typing import Optional, List
+from pydantic import BaseModel, Field
+
+
+class Rodin3DGenerateRequest(BaseModel):
+    seed: int = Field(..., description="Seed.")
+    tier: str = Field(..., description="Tier of generation.")
+    material: str = Field(..., description="The material type.")
+    quality: str = Field(..., description="The generation quality of the mesh.")
+    mesh_mode: str = Field(..., description="Controls the type of faces of the generated model.")
+
+class GenerateJobsData(BaseModel):
+    uuids: List[str] = Field(..., description="List of job UUIDs")
+    subscription_key: str = Field(..., description="Subscription key")
+
+class Rodin3DGenerateResponse(BaseModel):
+    message: Optional[str] = Field(None, description="Return message.")
+    prompt: Optional[str] = Field(None, description="Prompt generated from the image.")
+    submit_time: Optional[str] = Field(None, description="Submit time")
+    uuid: Optional[str] = Field(None, description="Task UUID")
+    jobs: Optional[GenerateJobsData] = Field(None, description="Details of jobs")
+
+class JobStatus(str, Enum):
+    """
+    Status for jobs
+    """
+    Done = "Done"
+    Failed = "Failed"
+    Generating = "Generating"
+    Waiting = "Waiting"
+
+class Rodin3DCheckStatusRequest(BaseModel):
+    subscription_key: str = Field(..., description="Subscription key from the generate endpoint")
+
+class JobItem(BaseModel):
+    uuid: str = Field(..., description="Job UUID")
+    status: JobStatus = Field(..., description="Current job status")
+
+class Rodin3DCheckStatusResponse(BaseModel):
+    jobs: List[JobItem] = Field(..., description="Job status list")
+
+class Rodin3DDownloadRequest(BaseModel):
+    task_uuid: str = Field(..., description="Task UUID")
+
+class RodinResourceItem(BaseModel):
+    url: str = Field(..., description="Download URL")
+    name: str = Field(..., description="File name with extension")
+
+class Rodin3DDownloadResponse(BaseModel):
+    list: List[RodinResourceItem] = Field(..., description="Resource list")
+
+
+
+
diff --git a/comfy_api_nodes/apis/tripo_api.py b/comfy_api_nodes/apis/tripo_api.py
new file mode 100644
index 000000000..626e8d277
--- /dev/null
+++ b/comfy_api_nodes/apis/tripo_api.py
@@ -0,0 +1,275 @@
+from __future__ import annotations
+from comfy_api_nodes.apis import (
+    TripoModelVersion,
+    TripoTextureQuality,
+)
+from enum import Enum
+from typing import Optional, List, Dict, Any, Union
+
+from pydantic import BaseModel, Field, RootModel
+
+class TripoStyle(str, Enum):
+    PERSON_TO_CARTOON = "person:person2cartoon"
+    ANIMAL_VENOM = "animal:venom"
+    OBJECT_CLAY = "object:clay"
+    OBJECT_STEAMPUNK = "object:steampunk"
+    OBJECT_CHRISTMAS = "object:christmas"
+    OBJECT_BARBIE = "object:barbie"
+    GOLD = "gold"
+    ANCIENT_BRONZE = "ancient_bronze"
+    NONE = "None"
+
+class TripoTaskType(str, Enum):
+    TEXT_TO_MODEL = "text_to_model"
+    IMAGE_TO_MODEL = "image_to_model"
+    MULTIVIEW_TO_MODEL = "multiview_to_model"
+    TEXTURE_MODEL = "texture_model"
+    REFINE_MODEL = "refine_model"
+    ANIMATE_PRERIGCHECK = "animate_prerigcheck"
+    ANIMATE_RIG = "animate_rig"
+    ANIMATE_RETARGET = "animate_retarget"
+    STYLIZE_MODEL = "stylize_model"
+    CONVERT_MODEL = "convert_model"
+
+class TripoTextureAlignment(str, Enum):
+    ORIGINAL_IMAGE = "original_image"
+    GEOMETRY = "geometry"
+
+class TripoOrientation(str, Enum):
+    ALIGN_IMAGE
= "align_image" + DEFAULT = "default" + +class TripoOutFormat(str, Enum): + GLB = "glb" + FBX = "fbx" + +class TripoTopology(str, Enum): + BIP = "bip" + QUAD = "quad" + +class TripoSpec(str, Enum): + MIXAMO = "mixamo" + TRIPO = "tripo" + +class TripoAnimation(str, Enum): + IDLE = "preset:idle" + WALK = "preset:walk" + CLIMB = "preset:climb" + JUMP = "preset:jump" + RUN = "preset:run" + SLASH = "preset:slash" + SHOOT = "preset:shoot" + HURT = "preset:hurt" + FALL = "preset:fall" + TURN = "preset:turn" + +class TripoStylizeStyle(str, Enum): + LEGO = "lego" + VOXEL = "voxel" + VORONOI = "voronoi" + MINECRAFT = "minecraft" + +class TripoConvertFormat(str, Enum): + GLTF = "GLTF" + USDZ = "USDZ" + FBX = "FBX" + OBJ = "OBJ" + STL = "STL" + _3MF = "3MF" + +class TripoTextureFormat(str, Enum): + BMP = "BMP" + DPX = "DPX" + HDR = "HDR" + JPEG = "JPEG" + OPEN_EXR = "OPEN_EXR" + PNG = "PNG" + TARGA = "TARGA" + TIFF = "TIFF" + WEBP = "WEBP" + +class TripoTaskStatus(str, Enum): + QUEUED = "queued" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + CANCELLED = "cancelled" + UNKNOWN = "unknown" + BANNED = "banned" + EXPIRED = "expired" + +class TripoFileTokenReference(BaseModel): + type: Optional[str] = Field(None, description='The type of the reference') + file_token: str + +class TripoUrlReference(BaseModel): + type: Optional[str] = Field(None, description='The type of the reference') + url: str + +class TripoObjectStorage(BaseModel): + bucket: str + key: str + +class TripoObjectReference(BaseModel): + type: str + object: TripoObjectStorage + +class TripoFileEmptyReference(BaseModel): + pass + +class TripoFileReference(RootModel): + root: Union[TripoFileTokenReference, TripoUrlReference, TripoObjectReference, TripoFileEmptyReference] + +class TripoGetStsTokenRequest(BaseModel): + format: str = Field(..., description='The format of the image') + +class TripoTextToModelRequest(BaseModel): + type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task') + prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024) + negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024) + model_version: Optional[TripoModelVersion] = TripoModelVersion.V2_5 + face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to') + texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model') + pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model') + image_seed: Optional[int] = Field(None, description='The seed for the text') + model_seed: Optional[int] = Field(None, description='The seed for the model') + texture_seed: Optional[int] = Field(None, description='The seed for the texture') + texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard + style: Optional[TripoStyle] = None + auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') + quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model') + +class TripoImageToModelRequest(BaseModel): + type: TripoTaskType = Field(TripoTaskType.IMAGE_TO_MODEL, description='Type of task') + file: TripoFileReference = Field(..., description='The file reference to convert to a model') + model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation') + face_limit: Optional[int] = Field(None, description='The number of 
faces to limit the generation to') + texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model') + pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model') + model_seed: Optional[int] = Field(None, description='The seed for the model') + texture_seed: Optional[int] = Field(None, description='The seed for the texture') + texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard + texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method') + style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model') + auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') + orientation: Optional[TripoOrientation] = TripoOrientation.DEFAULT + quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model') + +class TripoMultiviewToModelRequest(BaseModel): + type: TripoTaskType = TripoTaskType.MULTIVIEW_TO_MODEL + files: List[TripoFileReference] = Field(..., description='The file references to convert to a model') + model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation') + orthographic_projection: Optional[bool] = Field(False, description='Whether to use orthographic projection') + face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to') + texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model') + pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model') + model_seed: Optional[int] = Field(None, description='The seed for the model') + texture_seed: Optional[int] = Field(None, description='The seed for the texture') + texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard + texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE + auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') + orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model') + quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model') + +class TripoTextureModelRequest(BaseModel): + type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description='Type of task') + original_model_task_id: str = Field(..., description='The task ID of the original model') + texture: Optional[bool] = Field(True, description='Whether to apply texture to the model') + pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the model') + model_seed: Optional[int] = Field(None, description='The seed for the model') + texture_seed: Optional[int] = Field(None, description='The seed for the texture') + texture_quality: Optional[TripoTextureQuality] = Field(None, description='The quality of the texture') + texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method') + +class TripoRefineModelRequest(BaseModel): + type: TripoTaskType = Field(TripoTaskType.REFINE_MODEL, description='Type of task') + draft_model_task_id: str = Field(..., description='The task ID of the draft model') + +class TripoAnimatePrerigcheckRequest(BaseModel): + type: TripoTaskType = Field(TripoTaskType.ANIMATE_PRERIGCHECK, description='Type of task') + 
original_model_task_id: str = Field(..., description='The task ID of the original model') + +class TripoAnimateRigRequest(BaseModel): + type: TripoTaskType = Field(TripoTaskType.ANIMATE_RIG, description='Type of task') + original_model_task_id: str = Field(..., description='The task ID of the original model') + out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format') + spec: Optional[TripoSpec] = Field(TripoSpec.TRIPO, description='The specification for rigging') + +class TripoAnimateRetargetRequest(BaseModel): + type: TripoTaskType = Field(TripoTaskType.ANIMATE_RETARGET, description='Type of task') + original_model_task_id: str = Field(..., description='The task ID of the original model') + animation: TripoAnimation = Field(..., description='The animation to apply') + out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format') + bake_animation: Optional[bool] = Field(True, description='Whether to bake the animation') + +class TripoStylizeModelRequest(BaseModel): + type: TripoTaskType = Field(TripoTaskType.STYLIZE_MODEL, description='Type of task') + style: TripoStylizeStyle = Field(..., description='The style to apply to the model') + original_model_task_id: str = Field(..., description='The task ID of the original model') + block_size: Optional[int] = Field(80, description='The block size for stylization') + +class TripoConvertModelRequest(BaseModel): + type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task') + format: TripoConvertFormat = Field(..., description='The format to convert to') + original_model_task_id: str = Field(..., description='The task ID of the original model') + quad: Optional[bool] = Field(False, description='Whether to apply quad to the model') + force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry') + face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to') + flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model') + flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom') + texture_size: Optional[int] = Field(4096, description='The size of the texture') + texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture') + pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom') + +class TripoTaskRequest(RootModel): + root: Union[ + TripoTextToModelRequest, + TripoImageToModelRequest, + TripoMultiviewToModelRequest, + TripoTextureModelRequest, + TripoRefineModelRequest, + TripoAnimatePrerigcheckRequest, + TripoAnimateRigRequest, + TripoAnimateRetargetRequest, + TripoStylizeModelRequest, + TripoConvertModelRequest + ] + +class TripoTaskOutput(BaseModel): + model: Optional[str] = Field(None, description='URL to the model') + base_model: Optional[str] = Field(None, description='URL to the base model') + pbr_model: Optional[str] = Field(None, description='URL to the PBR model') + rendered_image: Optional[str] = Field(None, description='URL to the rendered image') + riggable: Optional[bool] = Field(None, description='Whether the model is riggable') + +class TripoTask(BaseModel): + task_id: str = Field(..., description='The task ID') + type: Optional[str] = Field(None, description='The type of task') + status: Optional[TripoTaskStatus] = Field(None, description='The status of the task') + 
input: Optional[Dict[str, Any]] = Field(None, description='The input parameters for the task') + output: Optional[TripoTaskOutput] = Field(None, description='The output of the task') + progress: Optional[int] = Field(None, description='The progress of the task', ge=0, le=100) + create_time: Optional[int] = Field(None, description='The creation time of the task') + running_left_time: Optional[int] = Field(None, description='The estimated time left for the task') + queue_position: Optional[int] = Field(None, description='The position in the queue') + +class TripoTaskResponse(BaseModel): + code: int = Field(0, description='The response code') + data: TripoTask = Field(..., description='The task data') + +class TripoGeneralResponse(BaseModel): + code: int = Field(0, description='The response code') + data: Dict[str, str] = Field(..., description='The task ID data') + +class TripoBalanceData(BaseModel): + balance: float = Field(..., description='The account balance') + frozen: float = Field(..., description='The frozen balance') + +class TripoBalanceResponse(BaseModel): + code: int = Field(0, description='The response code') + data: TripoBalanceData = Field(..., description='The balance data') + +class TripoErrorResponse(BaseModel): + code: int = Field(..., description='The error code') + message: str = Field(..., description='The error message') + suggestion: str = Field(..., description='The suggestion for fixing the error') diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 66ef1b391..d93fbd778 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -1,5 +1,6 @@ import io from inspect import cleandoc +from typing import Union, Optional from comfy.comfy_types.node_typing import IO, ComfyNodeABC from comfy_api_nodes.apis.bfl_api import ( BFLStatus, @@ -8,6 +9,7 @@ from comfy_api_nodes.apis.bfl_api import ( BFLFluxCannyImageRequest, BFLFluxDepthImageRequest, BFLFluxProGenerateRequest, + BFLFluxKontextProGenerateRequest, BFLFluxProUltraGenerateRequest, BFLFluxProGenerateResponse, ) @@ -30,6 +32,7 @@ import requests import torch import base64 import time +from server import PromptServer def convert_mask_to_image(mask: torch.Tensor): @@ -42,14 +45,19 @@ def convert_mask_to_image(mask: torch.Tensor): def handle_bfl_synchronous_operation( - operation: SynchronousOperation, timeout_bfl_calls=360 + operation: SynchronousOperation, + timeout_bfl_calls=360, + node_id: Union[str, None] = None, ): response_api: BFLFluxProGenerateResponse = operation.execute() return _poll_until_generated( - response_api.polling_url, timeout=timeout_bfl_calls + response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id ) -def _poll_until_generated(polling_url: str, timeout=360): + +def _poll_until_generated( + polling_url: str, timeout=360, node_id: Union[str, None] = None +): # used bfl-comfy-nodes to verify code implementation: # https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main start_time = time.time() @@ -61,11 +69,21 @@ def _poll_until_generated(polling_url: str, timeout=360): request = requests.Request(method=HttpMethod.GET, url=polling_url) # NOTE: should True loop be replaced with checking if workflow has been interrupted? 
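    # [Editor's note, not part of the patch] The loop below polls `polling_url`
    # until BFL reports a terminal status. With this change, when a `node_id` is
    # supplied, the elapsed time is pushed to the frontend on every iteration via
    # PromptServer.instance.send_progress_text(), and the result URL is reported
    # once the image is ready, so the node shows live progress while the request
    # is in flight.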
while True: + if node_id: + time_elapsed = time.time() - start_time + PromptServer.instance.send_progress_text( + f"Generating ({time_elapsed:.0f}s)", node_id + ) + response = requests.Session().send(request.prepare()) if response.status_code == 200: result = response.json() if result["status"] == BFLStatus.ready: img_url = result["result"]["sample"] + if node_id: + PromptServer.instance.send_progress_text( + f"Result URL: {img_url}", node_id + ) img_response = requests.get(img_url) return process_image_response(img_response) elif result["status"] in [ @@ -180,6 +198,7 @@ class FluxProUltraImageNode(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -212,6 +231,7 @@ class FluxProUltraImageNode(ComfyNodeABC): seed=0, image_prompt=None, image_prompt_strength=0.1, + unique_id: Union[str, None] = None, **kwargs, ): if image_prompt is None: @@ -246,10 +266,149 @@ class FluxProUltraImageNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation) + output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) +class FluxKontextProImageNode(ComfyNodeABC): + """ + Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio. + """ + + MINIMUM_RATIO = 1 / 4 + MAXIMUM_RATIO = 4 / 1 + MINIMUM_RATIO_STR = "1:4" + MAXIMUM_RATIO_STR = "4:1" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Prompt for the image generation - specify what and how to edit.", + }, + ), + "aspect_ratio": ( + IO.STRING, + { + "default": "16:9", + "tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.", + }, + ), + "guidance": ( + IO.FLOAT, + { + "default": 3.0, + "min": 0.1, + "max": 99.0, + "step": 0.1, + "tooltip": "Guidance strength for the image generation process" + }, + ), + "steps": ( + IO.INT, + { + "default": 50, + "min": 1, + "max": 150, + "tooltip": "Number of steps for the image generation process" + }, + ), + "seed": ( + IO.INT, + { + "default": 1234, + "min": 0, + "max": 0xFFFFFFFFFFFFFFFF, + "control_after_generate": True, + "tooltip": "The random seed used for creating the noise.", + }, + ), + "prompt_upsampling": ( + IO.BOOLEAN, + { + "default": False, + "tooltip": "Whether to perform upsampling on the prompt. 
If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + }, + ), + }, + "optional": { + "input_image": (IO.IMAGE,), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", + }, + } + + RETURN_TYPES = (IO.IMAGE,) + DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value + FUNCTION = "api_call" + API_NODE = True + CATEGORY = "api node/image/BFL" + + BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate" + + def api_call( + self, + prompt: str, + aspect_ratio: str, + guidance: float, + steps: int, + input_image: Optional[torch.Tensor]=None, + seed=0, + prompt_upsampling=False, + unique_id: Union[str, None] = None, + **kwargs, + ): + aspect_ratio = validate_aspect_ratio( + aspect_ratio, + minimum_ratio=self.MINIMUM_RATIO, + maximum_ratio=self.MAXIMUM_RATIO, + minimum_ratio_str=self.MINIMUM_RATIO_STR, + maximum_ratio_str=self.MAXIMUM_RATIO_STR, + ) + if input_image is None: + validate_string(prompt, strip_whitespace=False) + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=self.BFL_PATH, + method=HttpMethod.POST, + request_model=BFLFluxKontextProGenerateRequest, + response_model=BFLFluxProGenerateResponse, + ), + request=BFLFluxKontextProGenerateRequest( + prompt=prompt, + prompt_upsampling=prompt_upsampling, + guidance=round(guidance, 1), + steps=steps, + seed=seed, + aspect_ratio=aspect_ratio, + input_image=( + input_image + if input_image is None + else convert_image_to_base64(input_image) + ) + ), + auth_kwargs=kwargs, + ) + output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + return (output_image,) + + +class FluxKontextMaxImageNode(FluxKontextProImageNode): + """ + Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio. 
+ """ + + DESCRIPTION = cleandoc(__doc__ or "") + BFL_PATH = "/proxy/bfl/flux-kontext-max/generate" + class FluxProImageNode(ComfyNodeABC): """ @@ -320,6 +479,7 @@ class FluxProImageNode(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -338,6 +498,7 @@ class FluxProImageNode(ComfyNodeABC): seed=0, image_prompt=None, # image_prompt_strength=0.1, + unique_id: Union[str, None] = None, **kwargs, ): image_prompt = ( @@ -363,7 +524,7 @@ class FluxProImageNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation) + output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -457,11 +618,11 @@ class FluxProExpandNode(ComfyNodeABC): }, ), }, - "optional": { - }, + "optional": {}, "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -483,6 +644,7 @@ class FluxProExpandNode(ComfyNodeABC): steps: int, guidance: float, seed=0, + unique_id: Union[str, None] = None, **kwargs, ): image = convert_image_to_base64(image) @@ -508,7 +670,7 @@ class FluxProExpandNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation) + output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -568,11 +730,11 @@ class FluxProFillNode(ComfyNodeABC): }, ), }, - "optional": { - }, + "optional": {}, "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -591,13 +753,14 @@ class FluxProFillNode(ComfyNodeABC): steps: int, guidance: float, seed=0, + unique_id: Union[str, None] = None, **kwargs, ): # prepare mask mask = resize_mask_to_image(mask, image) mask = convert_image_to_base64(convert_mask_to_image(mask)) # make sure image will have alpha channel removed - image = convert_image_to_base64(image[:,:,:,:3]) + image = convert_image_to_base64(image[:, :, :, :3]) operation = SynchronousOperation( endpoint=ApiEndpoint( @@ -617,7 +780,7 @@ class FluxProFillNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation) + output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -702,11 +865,11 @@ class FluxProCannyNode(ComfyNodeABC): }, ), }, - "optional": { - }, + "optional": {}, "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -727,9 +890,10 @@ class FluxProCannyNode(ComfyNodeABC): steps: int, guidance: float, seed=0, + unique_id: Union[str, None] = None, **kwargs, ): - control_image = convert_image_to_base64(control_image[:,:,:,:3]) + control_image = convert_image_to_base64(control_image[:, :, :, :3]) preprocessed_image = None # scale canny threshold between 0-500, to match BFL's API @@ -765,7 +929,7 @@ class FluxProCannyNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation) + output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -830,11 +994,11 @@ class FluxProDepthNode(ComfyNodeABC): }, ), }, - "optional": { - }, + "optional": {}, "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -853,6 +1017,7 @@ class FluxProDepthNode(ComfyNodeABC): steps: int, guidance: float, seed=0, + unique_id: 
Union[str, None] = None, **kwargs, ): control_image = convert_image_to_base64(control_image[:,:,:,:3]) @@ -880,7 +1045,7 @@ class FluxProDepthNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation) + output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -889,6 +1054,8 @@ class FluxProDepthNode(ComfyNodeABC): NODE_CLASS_MAPPINGS = { "FluxProUltraImageNode": FluxProUltraImageNode, # "FluxProImageNode": FluxProImageNode, + "FluxKontextProImageNode": FluxKontextProImageNode, + "FluxKontextMaxImageNode": FluxKontextMaxImageNode, "FluxProExpandNode": FluxProExpandNode, "FluxProFillNode": FluxProFillNode, "FluxProCannyNode": FluxProCannyNode, @@ -899,6 +1066,8 @@ NODE_CLASS_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = { "FluxProUltraImageNode": "Flux 1.1 [pro] Ultra Image", # "FluxProImageNode": "Flux 1.1 [pro] Image", + "FluxKontextProImageNode": "Flux.1 Kontext [pro] Image", + "FluxKontextMaxImageNode": "Flux.1 Kontext [max] Image", "FluxProExpandNode": "Flux.1 Expand Image", "FluxProFillNode": "Flux.1 Fill Image", "FluxProCannyNode": "Flux.1 Canny Control Image", diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py new file mode 100644 index 000000000..ae7b04846 --- /dev/null +++ b/comfy_api_nodes/nodes_gemini.py @@ -0,0 +1,446 @@ +""" +API Nodes for Gemini Multimodal LLM Usage via Remote API +See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference +""" + +import os +from enum import Enum +from typing import Optional, Literal + +import torch + +import folder_paths +from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict +from server import PromptServer +from comfy_api_nodes.apis import ( + GeminiContent, + GeminiGenerateContentRequest, + GeminiGenerateContentResponse, + GeminiInlineData, + GeminiPart, + GeminiMimeType, +) +from comfy_api_nodes.apis.client import ( + ApiEndpoint, + HttpMethod, + SynchronousOperation, +) +from comfy_api_nodes.apinode_utils import ( + validate_string, + audio_to_base64_string, + video_to_base64_string, + tensor_to_base64_string, +) + + +GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini" +GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB + + +class GeminiModel(str, Enum): + """ + Gemini Model Names allowed by comfy-api + """ + + gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06" + gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17" + + +def get_gemini_endpoint( + model: GeminiModel, +) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]: + """ + Get the API endpoint for a given Gemini model. + + Args: + model: The Gemini model to use, either as enum or string value. + + Returns: + ApiEndpoint configured for the specific Gemini model. + """ + if isinstance(model, str): + model = GeminiModel(model) + return ApiEndpoint( + path=f"{GEMINI_BASE_ENDPOINT}/{model.value}", + method=HttpMethod.POST, + request_model=GeminiGenerateContentRequest, + response_model=GeminiGenerateContentResponse, + ) + + +class GeminiNode(ComfyNodeABC): + """ + Node to generate text responses from a Gemini model. + + This node allows users to interact with Google's Gemini AI models, providing + multimodal inputs (text, images, audio, video, files) to generate coherent + text responses. The node works with the latest Gemini models, handling the + API communication and response parsing. 
+ """ + + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": { + "prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Text inputs to the model, used to generate a response. You can include detailed instructions, questions, or context for the model.", + }, + ), + "model": ( + IO.COMBO, + { + "tooltip": "The Gemini model to use for generating responses.", + "options": [model.value for model in GeminiModel], + "default": GeminiModel.gemini_2_5_pro_preview_05_06.value, + }, + ), + "seed": ( + IO.INT, + { + "default": 42, + "min": 0, + "max": 0xFFFFFFFFFFFFFFFF, + "control_after_generate": True, + "tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.", + }, + ), + }, + "optional": { + "images": ( + IO.IMAGE, + { + "default": None, + "tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.", + }, + ), + "audio": ( + IO.AUDIO, + { + "tooltip": "Optional audio to use as context for the model.", + "default": None, + }, + ), + "video": ( + IO.VIDEO, + { + "tooltip": "Optional video to use as context for the model.", + "default": None, + }, + ), + "files": ( + "GEMINI_INPUT_FILES", + { + "default": None, + "tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.", + }, + ), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", + }, + } + + DESCRIPTION = "Generate text responses with Google's Gemini AI model. You can provide multiple types of inputs (text, images, audio, video) as context for generating more relevant and meaningful responses." + RETURN_TYPES = ("STRING",) + FUNCTION = "api_call" + CATEGORY = "api node/text/Gemini" + API_NODE = True + + def get_parts_from_response( + self, response: GeminiGenerateContentResponse + ) -> list[GeminiPart]: + """ + Extract all parts from the Gemini API response. + + Args: + response: The API response from Gemini. + + Returns: + List of response parts from the first candidate. + """ + return response.candidates[0].content.parts + + def get_parts_by_type( + self, response: GeminiGenerateContentResponse, part_type: Literal["text"] | str + ) -> list[GeminiPart]: + """ + Filter response parts by their type. + + Args: + response: The API response from Gemini. + part_type: Type of parts to extract ("text" or a MIME type). + + Returns: + List of response parts matching the requested type. + """ + parts = [] + for part in self.get_parts_from_response(response): + if part_type == "text" and hasattr(part, "text") and part.text: + parts.append(part) + elif ( + hasattr(part, "inlineData") + and part.inlineData + and part.inlineData.mimeType == part_type + ): + parts.append(part) + # Skip parts that don't match the requested type + return parts + + def get_text_from_response(self, response: GeminiGenerateContentResponse) -> str: + """ + Extract and concatenate all text parts from the response. + + Args: + response: The API response from Gemini. + + Returns: + Combined text from all text parts in the response. 
+ """ + parts = self.get_parts_by_type(response, "text") + return "\n".join([part.text for part in parts]) + + def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]: + """ + Convert video input to Gemini API compatible parts. + + Args: + video_input: Video tensor from ComfyUI. + **kwargs: Additional arguments to pass to the conversion function. + + Returns: + List of GeminiPart objects containing the encoded video. + """ + from comfy_api.util import VideoContainer, VideoCodec + base_64_string = video_to_base64_string( + video_input, + container_format=VideoContainer.MP4, + codec=VideoCodec.H264 + ) + return [ + GeminiPart( + inlineData=GeminiInlineData( + mimeType=GeminiMimeType.video_mp4, + data=base_64_string, + ) + ) + ] + + def create_audio_parts(self, audio_input: IO.AUDIO) -> list[GeminiPart]: + """ + Convert audio input to Gemini API compatible parts. + + Args: + audio_input: Audio input from ComfyUI, containing waveform tensor and sample rate. + + Returns: + List of GeminiPart objects containing the encoded audio. + """ + audio_parts: list[GeminiPart] = [] + for batch_index in range(audio_input["waveform"].shape[0]): + # Recreate an IO.AUDIO object for the given batch dimension index + audio_at_index = { + "waveform": audio_input["waveform"][batch_index].unsqueeze(0), + "sample_rate": audio_input["sample_rate"], + } + # Convert to MP3 format for compatibility with Gemini API + audio_bytes = audio_to_base64_string( + audio_at_index, + container_format="mp3", + codec_name="libmp3lame", + ) + audio_parts.append( + GeminiPart( + inlineData=GeminiInlineData( + mimeType=GeminiMimeType.audio_mp3, + data=audio_bytes, + ) + ) + ) + return audio_parts + + def create_image_parts(self, image_input: torch.Tensor) -> list[GeminiPart]: + """ + Convert image tensor input to Gemini API compatible parts. + + Args: + image_input: Batch of image tensors from ComfyUI. + + Returns: + List of GeminiPart objects containing the encoded images. + """ + image_parts: list[GeminiPart] = [] + for image_index in range(image_input.shape[0]): + image_as_b64 = tensor_to_base64_string( + image_input[image_index].unsqueeze(0) + ) + image_parts.append( + GeminiPart( + inlineData=GeminiInlineData( + mimeType=GeminiMimeType.image_png, + data=image_as_b64, + ) + ) + ) + return image_parts + + def create_text_part(self, text: str) -> GeminiPart: + """ + Create a text part for the Gemini API request. + + Args: + text: The text content to include in the request. + + Returns: + A GeminiPart object with the text content. 
+ """ + return GeminiPart(text=text) + + def api_call( + self, + prompt: str, + model: GeminiModel, + images: Optional[IO.IMAGE] = None, + audio: Optional[IO.AUDIO] = None, + video: Optional[IO.VIDEO] = None, + files: Optional[list[GeminiPart]] = None, + unique_id: Optional[str] = None, + **kwargs, + ) -> tuple[str]: + # Validate inputs + validate_string(prompt, strip_whitespace=False) + + # Create parts list with text prompt as the first part + parts: list[GeminiPart] = [self.create_text_part(prompt)] + + # Add other modal parts + if images is not None: + image_parts = self.create_image_parts(images) + parts.extend(image_parts) + if audio is not None: + parts.extend(self.create_audio_parts(audio)) + if video is not None: + parts.extend(self.create_video_parts(video)) + if files is not None: + parts.extend(files) + + # Create response + response = SynchronousOperation( + endpoint=get_gemini_endpoint(model), + request=GeminiGenerateContentRequest( + contents=[ + GeminiContent( + role="user", + parts=parts, + ) + ] + ), + auth_kwargs=kwargs, + ).execute() + + # Get result output + output_text = self.get_text_from_response(response) + if unique_id and output_text: + PromptServer.instance.send_progress_text(output_text, node_id=unique_id) + + return (output_text or "Empty response from Gemini model...",) + + +class GeminiInputFiles(ComfyNodeABC): + """ + Loads and formats input files for use with the Gemini API. + + This node allows users to include text (.txt) and PDF (.pdf) files as input + context for the Gemini model. Files are converted to the appropriate format + required by the API and can be chained together to include multiple files + in a single request. + """ + + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + """ + For details about the supported file input types, see: + https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference + """ + input_dir = folder_paths.get_input_directory() + input_files = [ + f + for f in os.scandir(input_dir) + if f.is_file() + and (f.name.endswith(".txt") or f.name.endswith(".pdf")) + and f.stat().st_size < GEMINI_MAX_INPUT_FILE_SIZE + ] + input_files = sorted(input_files, key=lambda x: x.name) + input_files = [f.name for f in input_files] + return { + "required": { + "file": ( + IO.COMBO, + { + "tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.", + "options": input_files, + "default": input_files[0] if input_files else None, + }, + ), + }, + "optional": { + "GEMINI_INPUT_FILES": ( + "GEMINI_INPUT_FILES", + { + "tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.", + "default": None, + }, + ), + }, + } + + DESCRIPTION = "Loads and prepares input files to include as inputs for Gemini LLM nodes. The files will be read by the Gemini model when generating a response. The contents of the text file count toward the token limit. 🛈 TIP: Can be chained together with other Gemini Input File nodes." 
+ RETURN_TYPES = ("GEMINI_INPUT_FILES",) + FUNCTION = "prepare_files" + CATEGORY = "api node/text/Gemini" + + def create_file_part(self, file_path: str) -> GeminiPart: + mime_type = ( + GeminiMimeType.pdf + if file_path.endswith(".pdf") + else GeminiMimeType.text_plain + ) + # Use base64 string directly, not the data URI + with open(file_path, "rb") as f: + file_content = f.read() + import base64 + base64_str = base64.b64encode(file_content).decode("utf-8") + + return GeminiPart( + inlineData=GeminiInlineData( + mimeType=mime_type, + data=base64_str, + ) + ) + + def prepare_files( + self, file: str, GEMINI_INPUT_FILES: list[GeminiPart] = [] + ) -> tuple[list[GeminiPart]]: + """ + Loads and formats input files for Gemini API. + """ + file_path = folder_paths.get_annotated_filepath(file) + input_file_content = self.create_file_part(file_path) + files = [input_file_content] + GEMINI_INPUT_FILES + return (files,) + + +NODE_CLASS_MAPPINGS = { + "GeminiNode": GeminiNode, + "GeminiInputFiles": GeminiInputFiles, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "GeminiNode": "Google Gemini", + "GeminiInputFiles": "Gemini Input Files", +} diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index d25468b17..b8487355f 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -23,6 +23,7 @@ from comfy_api_nodes.apinode_utils import ( bytesio_to_image_tensor, resize_mask_to_image, ) +from server import PromptServer V1_V1_RES_MAP = { "Auto":"AUTO", @@ -232,6 +233,19 @@ def download_and_process_images(image_urls): return stacked_tensors +def display_image_urls_on_node(image_urls, node_id): + if node_id and image_urls: + if len(image_urls) == 1: + PromptServer.instance.send_progress_text( + f"Generated Image URL:\n{image_urls[0]}", node_id + ) + else: + urls_text = "Generated Image URLs:\n" + "\n".join( + f"{i+1}. {url}" for i, url in enumerate(image_urls) + ) + PromptServer.instance.send_progress_text(urls_text, node_id) + + class IdeogramV1(ComfyNodeABC): """ Generates images using the Ideogram V1 model. 
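The recurring mechanism in this PR is the hidden UNIQUE_ID input combined with PromptServer.instance.send_progress_text(), used by the BFL, Gemini, Ideogram, Kling, Luma, and MiniMax nodes to show status text on a node while an API call is in flight. The snippet below is a minimal editorial sketch of that pattern, not part of the patch; ExampleApiNode and its category are hypothetical names, and it assumes it runs inside ComfyUI so the server module is importable.

# [Editor's sketch, illustrative only, not part of the patch]
from server import PromptServer  # ComfyUI server singleton, as imported above


class ExampleApiNode:  # hypothetical node name, for illustration
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {"prompt": ("STRING", {"multiline": True})},
            # ComfyUI fills "UNIQUE_ID" with this node's id at execution time.
            "hidden": {"unique_id": "UNIQUE_ID"},
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "run"
    CATEGORY = "api node/example"  # hypothetical category

    def run(self, prompt, unique_id=None):
        if unique_id:
            # Display live status text on this node in the frontend.
            PromptServer.instance.send_progress_text("Generating...", unique_id)
        return (prompt,)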
@@ -304,12 +318,13 @@ class IdeogramV1(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } RETURN_TYPES = (IO.IMAGE,) FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram/v1" + CATEGORY = "api node/image/Ideogram" DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True @@ -322,6 +337,7 @@ class IdeogramV1(ComfyNodeABC): seed=0, negative_prompt="", num_images=1, + unique_id=None, **kwargs, ): # Determine the model based on turbo setting @@ -361,6 +377,7 @@ class IdeogramV1(ComfyNodeABC): if not image_urls: raise Exception("No image URLs were generated in the response") + display_image_urls_on_node(image_urls, unique_id) return (download_and_process_images(image_urls),) @@ -460,12 +477,13 @@ class IdeogramV2(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } RETURN_TYPES = (IO.IMAGE,) FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram/v2" + CATEGORY = "api node/image/Ideogram" DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True @@ -481,6 +499,7 @@ class IdeogramV2(ComfyNodeABC): negative_prompt="", num_images=1, color_palette="", + unique_id=None, **kwargs, ): aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None) @@ -534,6 +553,7 @@ class IdeogramV2(ComfyNodeABC): if not image_urls: raise Exception("No image URLs were generated in the response") + display_image_urls_on_node(image_urls, unique_id) return (download_and_process_images(image_urls),) class IdeogramV3(ComfyNodeABC): @@ -623,12 +643,13 @@ class IdeogramV3(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } RETURN_TYPES = (IO.IMAGE,) FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram/v3" + CATEGORY = "api node/image/Ideogram" DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True @@ -643,6 +664,7 @@ class IdeogramV3(ComfyNodeABC): seed=0, num_images=1, rendering_speed="BALANCED", + unique_id=None, **kwargs, ): # Check if both image and mask are provided for editing mode @@ -762,6 +784,7 @@ class IdeogramV3(ComfyNodeABC): if not image_urls: raise Exception("No image URLs were generated in the response") + display_image_urls_on_node(image_urls, unique_id) return (download_and_process_images(image_urls),) @@ -776,4 +799,3 @@ NODE_DISPLAY_NAME_MAPPINGS = { "IdeogramV2": "Ideogram V2", "IdeogramV3": "Ideogram V3", } - diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 2d0fd8883..641cd6353 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -6,6 +6,7 @@ For source of truth on the allowed permutations of request fields, please refere from __future__ import annotations from typing import Optional, TypeVar, Any +from collections.abc import Callable import math import logging @@ -64,6 +65,12 @@ from comfy_api_nodes.apinode_utils import ( download_url_to_image_tensor, ) from comfy_api_nodes.mapper_utils import model_field_to_node_input +from comfy_api_nodes.util.validation_utils import ( + validate_image_dimensions, + validate_image_aspect_ratio, + validate_video_dimensions, + validate_video_duration, +) from comfy_api.input.basic_types import AudioInput from comfy_api.input.video_types import VideoInput from comfy_api.input_impl import VideoFromFile @@ -79,13 +86,20 @@ PATH_CHARACTER_IMAGE = f"/proxy/kling/{KLING_API_VERSION}/images/generations" PATH_VIRTUAL_TRY_ON = 
f"/proxy/kling/{KLING_API_VERSION}/images/kolors-virtual-try-on" PATH_IMAGE_GENERATIONS = f"/proxy/kling/{KLING_API_VERSION}/images/generations" - MAX_PROMPT_LENGTH_T2V = 2500 MAX_PROMPT_LENGTH_I2V = 500 MAX_PROMPT_LENGTH_IMAGE_GEN = 500 MAX_NEGATIVE_PROMPT_LENGTH_IMAGE_GEN = 200 MAX_PROMPT_LENGTH_LIP_SYNC = 120 +AVERAGE_DURATION_T2V = 319 +AVERAGE_DURATION_I2V = 164 +AVERAGE_DURATION_LIP_SYNC = 455 +AVERAGE_DURATION_VIRTUAL_TRY_ON = 19 +AVERAGE_DURATION_IMAGE_GEN = 32 +AVERAGE_DURATION_VIDEO_EFFECTS = 320 +AVERAGE_DURATION_VIDEO_EXTEND = 320 + R = TypeVar("R") @@ -95,7 +109,13 @@ class KlingApiError(Exception): pass -def poll_until_finished(auth_kwargs: dict[str,str], api_endpoint: ApiEndpoint[Any, R]) -> R: +def poll_until_finished( + auth_kwargs: dict[str, str], + api_endpoint: ApiEndpoint[Any, R], + result_url_extractor: Optional[Callable[[R], str]] = None, + estimated_duration: Optional[int] = None, + node_id: Optional[str] = None, +) -> R: """Polls the Kling API endpoint until the task reaches a terminal state, then returns the response.""" return PollingOperation( poll_endpoint=api_endpoint, @@ -109,6 +129,9 @@ def poll_until_finished(auth_kwargs: dict[str,str], api_endpoint: ApiEndpoint[An else None ), auth_kwargs=auth_kwargs, + result_url_extractor=result_url_extractor, + estimated_duration=estimated_duration, + node_id=node_id, ).execute() @@ -192,23 +215,8 @@ def validate_input_image(image: torch.Tensor) -> None: See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo """ - if len(image.shape) == 4: - height, width = image.shape[1], image.shape[2] - elif len(image.shape) == 3: - height, width = image.shape[0], image.shape[1] - else: - raise ValueError("Invalid image tensor shape.") - - # Ensure minimum resolution is met - if height < 300: - raise ValueError("Image height must be at least 300px") - if width < 300: - raise ValueError("Image width must be at least 300px") - - # Ensure aspect ratio is within acceptable range - aspect_ratio = width / height - if aspect_ratio < 1 / 2.5 or aspect_ratio > 2.5: - raise ValueError("Image aspect ratio must be between 1:2.5 and 2.5:1") + validate_image_dimensions(image, min_width=300, min_height=300) + validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5) def get_camera_control_input_config( @@ -227,7 +235,9 @@ def get_camera_control_input_config( def get_video_from_response(response) -> KlingVideoResult: - """Returns the first video object from the Kling video generation task result.""" + """Returns the first video object from the Kling video generation task result. + Will raise an error if the response is not valid. + """ video = response.data.task_result.videos[0] logging.info( "Kling task %s succeeded. Video URL: %s", response.data.task_id, video.url @@ -235,12 +245,37 @@ def get_video_from_response(response) -> KlingVideoResult: return video +def get_video_url_from_response(response) -> Optional[str]: + """Returns the first video url from the Kling video generation task result. + Will not raise an error if the response is not valid. + """ + if response and is_valid_video_response(response): + return str(get_video_from_response(response).url) + else: + return None + + def get_images_from_response(response) -> list[KlingImageResult]: + """Returns the list of image objects from the Kling image generation task result. + Will raise an error if the response is not valid. + """ images = response.data.task_result.images logging.info("Kling task %s succeeded. 
Images: %s", response.data.task_id, images) return images +def get_images_urls_from_response(response) -> Optional[str]: + """Returns the list of image urls from the Kling image generation task result. + Will not raise an error if the response is not valid. If there is only one image, returns the url as a string. If there are multiple images, returns a list of urls. + """ + if response and is_valid_image_response(response): + images = get_images_from_response(response) + image_urls = [str(image.url) for image in images] + return "\n".join(image_urls) + else: + return None + + def video_result_to_node_output( video: KlingVideoResult, ) -> tuple[VideoFromFile, str, str]: @@ -312,6 +347,7 @@ class KlingCameraControls(KlingNodeBase): RETURN_TYPES = ("CAMERA_CONTROL",) RETURN_NAMES = ("camera_control",) FUNCTION = "main" + API_NODE = False # This is just a helper node, it doesn't make an API call @classmethod def VALIDATE_INPUTS( @@ -421,6 +457,7 @@ class KlingTextToVideoNode(KlingNodeBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -428,7 +465,9 @@ class KlingTextToVideoNode(KlingNodeBase): RETURN_NAMES = ("VIDEO", "video_id", "duration") DESCRIPTION = "Kling Text to Video Node" - def get_response(self, task_id: str, auth_kwargs: dict[str,str]) -> KlingText2VideoResponse: + def get_response( + self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None + ) -> KlingText2VideoResponse: return poll_until_finished( auth_kwargs, ApiEndpoint( @@ -437,6 +476,9 @@ class KlingTextToVideoNode(KlingNodeBase): request_model=EmptyRequest, response_model=KlingText2VideoResponse, ), + result_url_extractor=get_video_url_from_response, + estimated_duration=AVERAGE_DURATION_T2V, + node_id=node_id, ) def api_call( @@ -449,6 +491,7 @@ class KlingTextToVideoNode(KlingNodeBase): camera_control: Optional[KlingCameraControl] = None, model_name: Optional[str] = None, duration: Optional[str] = None, + unique_id: Optional[str] = None, **kwargs, ) -> tuple[VideoFromFile, str, str]: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) @@ -478,7 +521,9 @@ class KlingTextToVideoNode(KlingNodeBase): validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response(task_id, auth_kwargs=kwargs) + final_response = self.get_response( + task_id, auth_kwargs=kwargs, node_id=unique_id + ) validate_video_result_response(final_response) video = get_video_from_response(final_response) @@ -528,6 +573,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -540,6 +586,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode): cfg_scale: float, aspect_ratio: str, camera_control: Optional[KlingCameraControl] = None, + unique_id: Optional[str] = None, **kwargs, ): return super().api_call( @@ -613,6 +660,7 @@ class KlingImage2VideoNode(KlingNodeBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -620,7 +668,9 @@ class KlingImage2VideoNode(KlingNodeBase): RETURN_NAMES = ("VIDEO", "video_id", "duration") DESCRIPTION = "Kling Image to Video Node" - def get_response(self, task_id: str, auth_kwargs: dict[str,str]) -> KlingImage2VideoResponse: + def get_response( + self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None + ) -> 
KlingImage2VideoResponse: return poll_until_finished( auth_kwargs, ApiEndpoint( @@ -629,6 +679,9 @@ class KlingImage2VideoNode(KlingNodeBase): request_model=KlingImage2VideoRequest, response_model=KlingImage2VideoResponse, ), + result_url_extractor=get_video_url_from_response, + estimated_duration=AVERAGE_DURATION_I2V, + node_id=node_id, ) def api_call( @@ -643,6 +696,7 @@ class KlingImage2VideoNode(KlingNodeBase): duration: str, camera_control: Optional[KlingCameraControl] = None, end_frame: Optional[torch.Tensor] = None, + unique_id: Optional[str] = None, **kwargs, ) -> tuple[VideoFromFile]: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V) @@ -681,7 +735,9 @@ class KlingImage2VideoNode(KlingNodeBase): validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response(task_id, auth_kwargs=kwargs) + final_response = self.get_response( + task_id, auth_kwargs=kwargs, node_id=unique_id + ) validate_video_result_response(final_response) video = get_video_from_response(final_response) @@ -734,6 +790,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -747,6 +804,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode): cfg_scale: float, aspect_ratio: str, camera_control: KlingCameraControl, + unique_id: Optional[str] = None, **kwargs, ): return super().api_call( @@ -759,6 +817,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode): prompt=prompt, negative_prompt=negative_prompt, camera_control=camera_control, + unique_id=unique_id, **kwargs, ) @@ -830,6 +889,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -844,6 +904,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode): cfg_scale: float, aspect_ratio: str, mode: str, + unique_id: Optional[str] = None, **kwargs, ): mode, duration, model_name = KlingStartEndFrameNode.get_mode_string_mapping()[ @@ -859,6 +920,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode): aspect_ratio=aspect_ratio, duration=duration, end_frame=end_frame, + unique_id=unique_id, **kwargs, ) @@ -892,6 +954,7 @@ class KlingVideoExtendNode(KlingNodeBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -899,7 +962,9 @@ class KlingVideoExtendNode(KlingNodeBase): RETURN_NAMES = ("VIDEO", "video_id", "duration") DESCRIPTION = "Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes." 
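    # [Editor's note, not part of the patch] The get_response() overrides revised
    # below all follow the same shape: they still delegate to poll_until_finished(),
    # but now also pass a result_url_extractor (so the finished video URL is surfaced
    # on the node), an estimated_duration taken from the AVERAGE_DURATION_* constants
    # added near the top of this file, and the node_id so polling progress can be
    # reported to the frontend.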
- def get_response(self, task_id: str, auth_kwargs: dict[str,str]) -> KlingVideoExtendResponse: + def get_response( + self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None + ) -> KlingVideoExtendResponse: return poll_until_finished( auth_kwargs, ApiEndpoint( @@ -908,6 +973,9 @@ class KlingVideoExtendNode(KlingNodeBase): request_model=EmptyRequest, response_model=KlingVideoExtendResponse, ), + result_url_extractor=get_video_url_from_response, + estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND, + node_id=node_id, ) def api_call( @@ -916,6 +984,7 @@ class KlingVideoExtendNode(KlingNodeBase): negative_prompt: str, cfg_scale: float, video_id: str, + unique_id: Optional[str] = None, **kwargs, ) -> tuple[VideoFromFile, str, str]: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) @@ -939,7 +1008,9 @@ class KlingVideoExtendNode(KlingNodeBase): validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response(task_id, auth_kwargs=kwargs) + final_response = self.get_response( + task_id, auth_kwargs=kwargs, node_id=unique_id + ) validate_video_result_response(final_response) video = get_video_from_response(final_response) @@ -952,7 +1023,9 @@ class KlingVideoEffectsBase(KlingNodeBase): RETURN_TYPES = ("VIDEO", "STRING", "STRING") RETURN_NAMES = ("VIDEO", "video_id", "duration") - def get_response(self, task_id: str, auth_kwargs: dict[str,str]) -> KlingVideoEffectsResponse: + def get_response( + self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None + ) -> KlingVideoEffectsResponse: return poll_until_finished( auth_kwargs, ApiEndpoint( @@ -961,6 +1034,9 @@ class KlingVideoEffectsBase(KlingNodeBase): request_model=EmptyRequest, response_model=KlingVideoEffectsResponse, ), + result_url_extractor=get_video_url_from_response, + estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS, + node_id=node_id, ) def api_call( @@ -972,6 +1048,7 @@ class KlingVideoEffectsBase(KlingNodeBase): image_1: torch.Tensor, image_2: Optional[torch.Tensor] = None, mode: Optional[KlingVideoGenMode] = None, + unique_id: Optional[str] = None, **kwargs, ): if dual_character: @@ -1009,7 +1086,9 @@ class KlingVideoEffectsBase(KlingNodeBase): validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response(task_id, auth_kwargs=kwargs) + final_response = self.get_response( + task_id, auth_kwargs=kwargs, node_id=unique_id + ) validate_video_result_response(final_response) video = get_video_from_response(final_response) @@ -1053,6 +1132,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -1068,6 +1148,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase): model_name: KlingCharacterEffectModelName, mode: KlingVideoGenMode, duration: KlingVideoGenDuration, + unique_id: Optional[str] = None, **kwargs, ): video, _, duration = super().api_call( @@ -1078,10 +1159,12 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase): duration=duration, image_1=image_left, image_2=image_right, + unique_id=unique_id, **kwargs, ) return video, duration + class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase): """Kling Single Image Video Effect Node""" @@ -1117,6 +1200,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", 
"comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -1128,6 +1212,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase): effect_scene: KlingSingleImageEffectsScene, model_name: KlingSingleImageEffectModelName, duration: KlingVideoGenDuration, + unique_id: Optional[str] = None, **kwargs, ): return super().api_call( @@ -1136,6 +1221,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase): model_name=model_name, duration=duration, image_1=image, + unique_id=unique_id, **kwargs, ) @@ -1146,6 +1232,17 @@ class KlingLipSyncBase(KlingNodeBase): RETURN_TYPES = ("VIDEO", "STRING", "STRING") RETURN_NAMES = ("VIDEO", "video_id", "duration") + def validate_lip_sync_video(self, video: VideoInput): + """ + Validates the input video adheres to the expectations of the Kling Lip Sync API: + - Video length does not exceed 10s and is not shorter than 2s + - Length and width dimensions should both be between 720px and 1920px + + See: https://app.klingai.com/global/dev/document-api/apiReference/model/videoTolip + """ + validate_video_dimensions(video, 720, 1920) + validate_video_duration(video, 2, 10) + def validate_text(self, text: str): if not text: raise ValueError("Text is required") @@ -1154,7 +1251,9 @@ class KlingLipSyncBase(KlingNodeBase): f"Text is too long. Maximum length is {MAX_PROMPT_LENGTH_LIP_SYNC} characters." ) - def get_response(self, task_id: str, auth_kwargs: dict[str,str]) -> KlingLipSyncResponse: + def get_response( + self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None + ) -> KlingLipSyncResponse: """Polls the Kling API endpoint until the task reaches a terminal state.""" return poll_until_finished( auth_kwargs, @@ -1164,6 +1263,9 @@ class KlingLipSyncBase(KlingNodeBase): request_model=EmptyRequest, response_model=KlingLipSyncResponse, ), + result_url_extractor=get_video_url_from_response, + estimated_duration=AVERAGE_DURATION_LIP_SYNC, + node_id=node_id, ) def api_call( @@ -1175,10 +1277,12 @@ class KlingLipSyncBase(KlingNodeBase): text: Optional[str] = None, voice_speed: Optional[float] = None, voice_id: Optional[str] = None, - **kwargs + unique_id: Optional[str] = None, + **kwargs, ) -> tuple[VideoFromFile, str, str]: if text: self.validate_text(text) + self.validate_lip_sync_video(video) # Upload video to Comfy API and get download URL video_url = upload_video_to_comfyapi(video, auth_kwargs=kwargs) @@ -1217,7 +1321,9 @@ class KlingLipSyncBase(KlingNodeBase): validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response(task_id, auth_kwargs=kwargs) + final_response = self.get_response( + task_id, auth_kwargs=kwargs, node_id=unique_id + ) validate_video_result_response(final_response) video = get_video_from_response(final_response) @@ -1243,16 +1349,18 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } - DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file." + DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. 
The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length." def api_call( self, video: VideoInput, audio: AudioInput, voice_language: str, + unique_id: Optional[str] = None, **kwargs, ): return super().api_call( @@ -1260,6 +1368,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase): audio=audio, voice_language=voice_language, mode="audio2video", + unique_id=unique_id, **kwargs, ) @@ -1352,10 +1461,11 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } - DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt." + DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length." def api_call( self, @@ -1363,6 +1473,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase): text: str, voice: str, voice_speed: float, + unique_id: Optional[str] = None, **kwargs, ): voice_id, voice_language = KlingLipSyncTextToVideoNode.get_voice_config()[voice] @@ -1373,6 +1484,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase): voice_id=voice_id, voice_speed=voice_speed, mode="text2video", + unique_id=unique_id, **kwargs, ) @@ -1413,13 +1525,14 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } - DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human." + DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background." def get_response( - self, task_id: str, auth_kwargs: dict[str,str] = None + self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> KlingVirtualTryOnResponse: return poll_until_finished( auth_kwargs, @@ -1429,6 +1542,9 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase): request_model=EmptyRequest, response_model=KlingVirtualTryOnResponse, ), + result_url_extractor=get_images_urls_from_response, + estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON, + node_id=node_id, ) def api_call( @@ -1436,6 +1552,7 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase): human_image: torch.Tensor, cloth_image: torch.Tensor, model_name: KlingVirtualTryOnModelName, + unique_id: Optional[str] = None, **kwargs, ): initial_operation = SynchronousOperation( @@ -1457,7 +1574,9 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase): validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response(task_id, auth_kwargs=kwargs) + final_response = self.get_response( + task_id, auth_kwargs=kwargs, node_id=unique_id + ) validate_image_result_response(final_response) images = get_images_from_response(final_response) @@ -1528,13 +1647,17 @@ class KlingImageGenerationNode(KlingImageGenerationBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } DESCRIPTION = "Kling Image Generation Node. Generate an image from a text prompt with an optional reference image." 
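    # [Editor's note, not part of the patch] For the image-producing Kling nodes
    # (virtual try-on and image generation), the result_url_extractor is
    # get_images_urls_from_response, which returns the image URLs joined with
    # newlines (or None on an invalid response) so multi-image results can be shown
    # as a single progress-text message.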
def get_response( - self, task_id: str, auth_kwargs: Optional[dict[str,str]] = None + self, + task_id: str, + auth_kwargs: Optional[dict[str, str]], + node_id: Optional[str] = None, ) -> KlingImageGenerationsResponse: return poll_until_finished( auth_kwargs, @@ -1544,6 +1667,9 @@ class KlingImageGenerationNode(KlingImageGenerationBase): request_model=EmptyRequest, response_model=KlingImageGenerationsResponse, ), + result_url_extractor=get_images_urls_from_response, + estimated_duration=AVERAGE_DURATION_IMAGE_GEN, + node_id=node_id, ) def api_call( @@ -1557,6 +1683,7 @@ class KlingImageGenerationNode(KlingImageGenerationBase): n: int, aspect_ratio: KlingImageGenAspectRatio, image: Optional[torch.Tensor] = None, + unique_id: Optional[str] = None, **kwargs, ): self.validate_prompt(prompt, negative_prompt) @@ -1589,7 +1716,9 @@ class KlingImageGenerationNode(KlingImageGenerationBase): validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response(task_id, auth_kwargs=kwargs) + final_response = self.get_response( + task_id, auth_kwargs=kwargs, node_id=unique_id + ) validate_image_result_response(final_response) images = get_images_from_response(final_response) diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index bd33a53e0..525dc38e6 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -36,11 +36,20 @@ from comfy_api_nodes.apinode_utils import ( process_image_response, validate_string, ) +from server import PromptServer import requests import torch from io import BytesIO +LUMA_T2V_AVERAGE_DURATION = 105 +LUMA_I2V_AVERAGE_DURATION = 100 + +def image_result_url_extractor(response: LumaGeneration): + return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None + +def video_result_url_extractor(response: LumaGeneration): + return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None class LumaReferenceNode(ComfyNodeABC): """ @@ -204,6 +213,7 @@ class LumaImageGenerationNode(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -217,6 +227,7 @@ class LumaImageGenerationNode(ComfyNodeABC): image_luma_ref: LumaReferenceChain = None, style_image: torch.Tensor = None, character_image: torch.Tensor = None, + unique_id: str = None, **kwargs, ): validate_string(prompt, strip_whitespace=True, min_length=3) @@ -271,6 +282,8 @@ class LumaImageGenerationNode(ComfyNodeABC): completed_statuses=[LumaState.completed], failed_statuses=[LumaState.failed], status_extractor=lambda x: x.state, + result_url_extractor=image_result_url_extractor, + node_id=unique_id, auth_kwargs=kwargs, ) response_poll = operation.execute() @@ -353,6 +366,7 @@ class LumaImageModifyNode(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -363,6 +377,7 @@ class LumaImageModifyNode(ComfyNodeABC): image: torch.Tensor, image_weight: float, seed, + unique_id: str = None, **kwargs, ): # first, upload image @@ -399,6 +414,8 @@ class LumaImageModifyNode(ComfyNodeABC): completed_statuses=[LumaState.completed], failed_statuses=[LumaState.failed], status_extractor=lambda x: x.state, + result_url_extractor=image_result_url_extractor, + node_id=unique_id, auth_kwargs=kwargs, ) response_poll = operation.execute() @@ -473,6 +490,7 @@ class 
LumaTextToVideoGenerationNode(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -486,6 +504,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC): loop: bool, seed, luma_concepts: LumaConceptChain = None, + unique_id: str = None, **kwargs, ): validate_string(prompt, strip_whitespace=False, min_length=3) @@ -512,6 +531,9 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC): ) response_api: LumaGeneration = operation.execute() + if unique_id: + PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id) + operation = PollingOperation( poll_endpoint=ApiEndpoint( path=f"/proxy/luma/generations/{response_api.id}", @@ -522,6 +544,9 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC): completed_statuses=[LumaState.completed], failed_statuses=[LumaState.failed], status_extractor=lambda x: x.state, + result_url_extractor=video_result_url_extractor, + node_id=unique_id, + estimated_duration=LUMA_T2V_AVERAGE_DURATION, auth_kwargs=kwargs, ) response_poll = operation.execute() @@ -597,6 +622,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -611,6 +637,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): first_image: torch.Tensor = None, last_image: torch.Tensor = None, luma_concepts: LumaConceptChain = None, + unique_id: str = None, **kwargs, ): if first_image is None and last_image is None: @@ -642,6 +669,9 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): ) response_api: LumaGeneration = operation.execute() + if unique_id: + PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id) + operation = PollingOperation( poll_endpoint=ApiEndpoint( path=f"/proxy/luma/generations/{response_api.id}", @@ -652,6 +682,9 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): completed_statuses=[LumaState.completed], failed_statuses=[LumaState.failed], status_extractor=lambda x: x.state, + result_url_extractor=video_result_url_extractor, + node_id=unique_id, + estimated_duration=LUMA_I2V_AVERAGE_DURATION, auth_kwargs=kwargs, ) response_poll = operation.execute() diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index fd64aeb0b..9b46636db 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -1,3 +1,7 @@ +from typing import Union +import logging +import torch + from comfy.comfy_types.node_typing import IO from comfy_api.input_impl.video_types import VideoFromFile from comfy_api_nodes.apis import ( @@ -20,16 +24,19 @@ from comfy_api_nodes.apinode_utils import ( upload_images_to_comfyapi, validate_string, ) +from server import PromptServer -import torch -import logging +I2V_AVERAGE_DURATION = 114 +T2V_AVERAGE_DURATION = 234 class MinimaxTextToVideoNode: """ Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API. 
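The `hasattr`-guarded extractors defined for the Luma nodes above can be written a little more defensively with `getattr`; this is an equivalent sketch, not a drop-in from the diff:

```python
from types import SimpleNamespace
from typing import Optional

def safe_asset_url(response: object, kind: str) -> Optional[str]:
    """Return response.assets.<kind> if the whole chain exists, else None."""
    assets = getattr(response, "assets", None)
    return getattr(assets, kind, None) if assets is not None else None

# Behaves like video_result_url_extractor / image_result_url_extractor:
resp = SimpleNamespace(assets=SimpleNamespace(video="https://example.com/out.mp4"))
assert safe_asset_url(resp, "video") == "https://example.com/out.mp4"
assert safe_asset_url(SimpleNamespace(), "video") is None
```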
""" + AVERAGE_DURATION = T2V_AVERAGE_DURATION + @classmethod def INPUT_TYPES(s): return { @@ -68,6 +75,7 @@ class MinimaxTextToVideoNode: "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -85,6 +93,7 @@ class MinimaxTextToVideoNode: model="T2V-01", image: torch.Tensor=None, # used for ImageToVideo subject: torch.Tensor=None, # used for SubjectToVideo + unique_id: Union[str, None]=None, **kwargs, ): ''' @@ -138,6 +147,8 @@ class MinimaxTextToVideoNode: completed_statuses=["Success"], failed_statuses=["Fail"], status_extractor=lambda x: x.status.value, + estimated_duration=self.AVERAGE_DURATION, + node_id=unique_id, auth_kwargs=kwargs, ) task_result = video_generate_operation.execute() @@ -164,6 +175,12 @@ class MinimaxTextToVideoNode: f"No video was found in the response. Full response: {file_result.model_dump()}" ) logging.info(f"Generated video URL: {file_url}") + if unique_id: + if hasattr(file_result.file, "backup_download_url"): + message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}" + else: + message = f"Result URL: {file_url}" + PromptServer.instance.send_progress_text(message, unique_id) video_io = download_url_to_bytesio(file_url) if video_io is None: @@ -178,6 +195,8 @@ class MinimaxImageToVideoNode(MinimaxTextToVideoNode): Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API. """ + AVERAGE_DURATION = I2V_AVERAGE_DURATION + @classmethod def INPUT_TYPES(s): return { @@ -223,6 +242,7 @@ class MinimaxImageToVideoNode(MinimaxTextToVideoNode): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -239,6 +259,8 @@ class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode): Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API. 
""" + AVERAGE_DURATION = T2V_AVERAGE_DURATION + @classmethod def INPUT_TYPES(s): return { @@ -282,6 +304,7 @@ class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index c63908be2..be1d2de4a 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -1,29 +1,86 @@ import io +from typing import TypedDict, Optional +import json +import os +import time +import re +import uuid +from enum import Enum from inspect import cleandoc import numpy as np import torch from PIL import Image - from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict +from server import PromptServer +import folder_paths from comfy_api_nodes.apis import ( OpenAIImageGenerationRequest, OpenAIImageEditRequest, OpenAIImageGenerationResponse, + OpenAICreateResponse, + OpenAIResponse, + CreateModelResponseProperties, + Item, + Includable, + OutputContent, + InputImageContent, + Detail, + InputTextContent, + InputMessage, + InputMessageContentList, + InputContent, + InputFileContent, ) from comfy_api_nodes.apis.client import ( ApiEndpoint, HttpMethod, SynchronousOperation, + PollingOperation, + EmptyRequest, ) from comfy_api_nodes.apinode_utils import ( downscale_image_tensor, validate_and_cast_response, validate_string, + tensor_to_base64_string, + text_filepath_to_data_uri, ) +from comfy_api_nodes.mapper_utils import model_field_to_node_input + + +RESPONSES_ENDPOINT = "/proxy/openai/v1/responses" +STARTING_POINT_ID_PATTERN = r"" + + +class HistoryEntry(TypedDict): + """Type definition for a single history entry in the chat.""" + + prompt: str + response: str + response_id: str + timestamp: float + + +class ChatHistory(TypedDict): + """Type definition for the chat history dictionary.""" + + __annotations__: dict[str, list[HistoryEntry]] + + +class SupportedOpenAIModel(str, Enum): + o4_mini = "o4-mini" + o1 = "o1" + o3 = "o3" + o1_pro = "o1-pro" + gpt_4o = "gpt-4o" + gpt_4_1 = "gpt-4.1" + gpt_4_1_mini = "gpt-4.1-mini" + gpt_4_1_nano = "gpt-4.1-nano" + class OpenAIDalle2(ComfyNodeABC): """ @@ -96,6 +153,7 @@ class OpenAIDalle2(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -113,7 +171,8 @@ class OpenAIDalle2(ComfyNodeABC): mask=None, n=1, size="1024x1024", - **kwargs + unique_id=None, + **kwargs, ): validate_string(prompt, strip_whitespace=False) model = "dall-e-2" @@ -176,7 +235,7 @@ class OpenAIDalle2(ComfyNodeABC): response = operation.execute() - img_tensor = validate_and_cast_response(response) + img_tensor = validate_and_cast_response(response, node_id=unique_id) return (img_tensor,) @@ -242,6 +301,7 @@ class OpenAIDalle3(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -258,7 +318,8 @@ class OpenAIDalle3(ComfyNodeABC): style="natural", quality="standard", size="1024x1024", - **kwargs + unique_id=None, + **kwargs, ): validate_string(prompt, strip_whitespace=False) model = "dall-e-3" @@ -284,7 +345,7 @@ class OpenAIDalle3(ComfyNodeABC): response = operation.execute() - img_tensor = validate_and_cast_response(response) + img_tensor = validate_and_cast_response(response, node_id=unique_id) return (img_tensor,) @@ -375,6 +436,7 @@ class OpenAIGPTImage1(ComfyNodeABC): "hidden": { "auth_token": 
"AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -394,12 +456,13 @@ class OpenAIGPTImage1(ComfyNodeABC): mask=None, n=1, size="1024x1024", - **kwargs + unique_id=None, + **kwargs, ): validate_string(prompt, strip_whitespace=False) model = "gpt-image-1" path = "/proxy/openai/images/generations" - content_type="application/json" + content_type = "application/json" request_class = OpenAIImageGenerationRequest img_binaries = [] mask_binary = None @@ -408,7 +471,7 @@ class OpenAIGPTImage1(ComfyNodeABC): if image is not None: path = "/proxy/openai/images/edits" request_class = OpenAIImageEditRequest - content_type ="multipart/form-data" + content_type = "multipart/form-data" batch_size = image.shape[0] @@ -476,21 +539,470 @@ class OpenAIGPTImage1(ComfyNodeABC): response = operation.execute() - img_tensor = validate_and_cast_response(response) + img_tensor = validate_and_cast_response(response, node_id=unique_id) return (img_tensor,) -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique +class OpenAITextNode(ComfyNodeABC): + """ + Base class for OpenAI text generation nodes. + """ + + RETURN_TYPES = (IO.STRING,) + FUNCTION = "api_call" + CATEGORY = "api node/text/OpenAI" + API_NODE = True + + +class OpenAIChatNode(OpenAITextNode): + """ + Node to generate text responses from an OpenAI model. + """ + + def __init__(self) -> None: + """Initialize the chat node with a new session ID and empty history.""" + self.current_session_id: str = str(uuid.uuid4()) + self.history: dict[str, list[HistoryEntry]] = {} + self.previous_response_id: Optional[str] = None + + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": { + "prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Text inputs to the model, used to generate a response.", + }, + ), + "persist_context": ( + IO.BOOLEAN, + { + "default": True, + "tooltip": "Persist chat context between calls (multi-turn conversation)", + }, + ), + "model": model_field_to_node_input( + IO.COMBO, + OpenAICreateResponse, + "model", + enum_type=SupportedOpenAIModel, + ), + }, + "optional": { + "images": ( + IO.IMAGE, + { + "default": None, + "tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.", + }, + ), + "files": ( + "OPENAI_INPUT_FILES", + { + "default": None, + "tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the OpenAI Chat Input Files node.", + }, + ), + "advanced_options": ( + "OPENAI_CHAT_CONFIG", + { + "default": None, + "tooltip": "Optional configuration for the model. Accepts inputs from the OpenAI Chat Advanced Options node.", + }, + ), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", + }, + } + + DESCRIPTION = "Generate text responses from an OpenAI model." + + def get_result_response( + self, + response_id: str, + include: Optional[list[Includable]] = None, + auth_kwargs: Optional[dict[str, str]] = None, + ) -> OpenAIResponse: + """ + Retrieve a model response with the given ID from the OpenAI API. + + Args: + response_id (str): The ID of the response to retrieve. + include (Optional[List[Includable]]): Additional fields to include + in the response. See the `include` parameter for Response + creation above for more information. 
+ + """ + return PollingOperation( + poll_endpoint=ApiEndpoint( + path=f"{RESPONSES_ENDPOINT}/{response_id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=OpenAIResponse, + query_params={"include": include}, + ), + completed_statuses=["completed"], + failed_statuses=["failed"], + status_extractor=lambda response: response.status, + auth_kwargs=auth_kwargs, + ).execute() + + def get_message_content_from_response( + self, response: OpenAIResponse + ) -> list[OutputContent]: + """Extract message content from the API response.""" + for output in response.output: + if output.root.type == "message": + return output.root.content + raise TypeError("No output message found in response") + + def get_text_from_message_content( + self, message_content: list[OutputContent] + ) -> str: + """Extract text content from message content.""" + for content_item in message_content: + if content_item.root.type == "output_text": + return str(content_item.root.text) + return "No text output found in response" + + def get_history_text(self, session_id: str) -> str: + """Convert the entire history for a given session to JSON string.""" + return json.dumps(self.history[session_id]) + + def display_history_on_node(self, session_id: str, node_id: str) -> None: + """Display formatted chat history on the node UI.""" + render_spec = { + "node_id": node_id, + "component": "ChatHistoryWidget", + "props": { + "history": self.get_history_text(session_id), + }, + } + PromptServer.instance.send_sync( + "display_component", + render_spec, + ) + + def add_to_history( + self, session_id: str, prompt: str, output_text: str, response_id: str + ) -> None: + """Add a new entry to the chat history.""" + if session_id not in self.history: + self.history[session_id] = [] + self.history[session_id].append( + { + "prompt": prompt, + "response": output_text, + "response_id": response_id, + "timestamp": time.time(), + } + ) + + def parse_output_text_from_response(self, response: OpenAIResponse) -> str: + """Extract text output from the API response.""" + message_contents = self.get_message_content_from_response(response) + return self.get_text_from_message_content(message_contents) + + def generate_new_session_id(self) -> str: + """Generate a new unique session ID.""" + return str(uuid.uuid4()) + + def get_session_id(self, persist_context: bool) -> str: + """Get the current or generate a new session ID based on context persistence.""" + return ( + self.current_session_id + if persist_context + else self.generate_new_session_id() + ) + + def tensor_to_input_image_content( + self, image: torch.Tensor, detail_level: Detail = "auto" + ) -> InputImageContent: + """Convert a tensor to an input image content object.""" + return InputImageContent( + detail=detail_level, + image_url=f"data:image/png;base64,{tensor_to_base64_string(image)}", + type="input_image", + ) + + def create_input_message_contents( + self, + prompt: str, + image: Optional[torch.Tensor] = None, + files: Optional[list[InputFileContent]] = None, + ) -> InputMessageContentList: + """Create a list of input message contents from prompt and optional image.""" + content_list: list[InputContent] = [ + InputTextContent(text=prompt, type="input_text"), + ] + if image is not None: + for i in range(image.shape[0]): + content_list.append( + self.tensor_to_input_image_content(image[i].unsqueeze(0)) + ) + if files is not None: + content_list.extend(files) + + return InputMessageContentList( + root=content_list, + ) + + def parse_response_id_from_prompt(self, prompt: 
str) -> Optional[str]: + """Extract response ID from prompt if it exists.""" + parsed_id = re.search(STARTING_POINT_ID_PATTERN, prompt) + return parsed_id.group(1) if parsed_id else None + + def strip_response_tag_from_prompt(self, prompt: str) -> str: + """Remove the response ID tag from the prompt.""" + return re.sub(STARTING_POINT_ID_PATTERN, "", prompt.strip()) + + def delete_history_after_response_id( + self, new_start_id: str, session_id: str + ) -> None: + """Delete history entries after a specific response ID.""" + if session_id not in self.history: + return + + new_history = [] + i = 0 + while ( + i < len(self.history[session_id]) + and self.history[session_id][i]["response_id"] != new_start_id + ): + new_history.append(self.history[session_id][i]) + i += 1 + + # Since it's the new starting point (not the response being edited), we include it as well + if i < len(self.history[session_id]): + new_history.append(self.history[session_id][i]) + + self.history[session_id] = new_history + + def api_call( + self, + prompt: str, + persist_context: bool, + model: SupportedOpenAIModel, + unique_id: Optional[str] = None, + images: Optional[torch.Tensor] = None, + files: Optional[list[InputFileContent]] = None, + advanced_options: Optional[CreateModelResponseProperties] = None, + **kwargs, + ) -> tuple[str]: + # Validate inputs + validate_string(prompt, strip_whitespace=False) + + session_id = self.get_session_id(persist_context) + response_id_override = self.parse_response_id_from_prompt(prompt) + if response_id_override: + is_starting_from_beginning = response_id_override == "start" + if is_starting_from_beginning: + self.history[session_id] = [] + previous_response_id = None + else: + previous_response_id = response_id_override + self.delete_history_after_response_id(response_id_override, session_id) + prompt = self.strip_response_tag_from_prompt(prompt) + elif persist_context: + previous_response_id = self.previous_response_id + else: + previous_response_id = None + + # Create response + create_response = SynchronousOperation( + endpoint=ApiEndpoint( + path=RESPONSES_ENDPOINT, + method=HttpMethod.POST, + request_model=OpenAICreateResponse, + response_model=OpenAIResponse, + ), + request=OpenAICreateResponse( + input=[ + Item( + root=InputMessage( + content=self.create_input_message_contents( + prompt, images, files + ), + role="user", + ) + ), + ], + store=True, + stream=False, + model=model, + previous_response_id=previous_response_id, + **( + advanced_options.model_dump(exclude_none=True) + if advanced_options + else {} + ), + ), + auth_kwargs=kwargs, + ).execute() + response_id = create_response.id + + # Get result output + result_response = self.get_result_response(response_id, auth_kwargs=kwargs) + output_text = self.parse_output_text_from_response(result_response) + + # Update history + self.add_to_history(session_id, prompt, output_text, response_id) + self.display_history_on_node(session_id, unique_id) + self.previous_response_id = response_id + + return (output_text,) + + +class OpenAIInputFiles(ComfyNodeABC): + """ + Loads and formats input files for OpenAI API. 
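`delete_history_after_response_id` above walks the list with an index so it can keep the matched entry as the new starting point. A more compact slice-based equivalent (note: the `<starting_point_id:...>` tag format driving `parse_response_id_from_prompt` is an assumption here, since the pattern literal does not survive in this diff):

```python
def trim_history_after(entries: list[dict], new_start_id: str) -> list[dict]:
    """Keep entries up to and including the one matching new_start_id."""
    for i, entry in enumerate(entries):
        if entry["response_id"] == new_start_id:
            return entries[: i + 1]
    return list(entries)  # id not found: history is left untouched

entries = [{"response_id": "a"}, {"response_id": "b"}, {"response_id": "c"}]
assert trim_history_after(entries, "b") == [{"response_id": "a"}, {"response_id": "b"}]
```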
+ """ + + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + """ + For details about the supported file input types, see: + https://platform.openai.com/docs/guides/pdf-files?api-mode=responses + """ + input_dir = folder_paths.get_input_directory() + input_files = [ + f + for f in os.scandir(input_dir) + if f.is_file() + and (f.name.endswith(".txt") or f.name.endswith(".pdf")) + and f.stat().st_size < 32 * 1024 * 1024 + ] + input_files = sorted(input_files, key=lambda x: x.name) + input_files = [f.name for f in input_files] + return { + "required": { + "file": ( + IO.COMBO, + { + "tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.", + "options": input_files, + "default": input_files[0] if input_files else None, + }, + ), + }, + "optional": { + "OPENAI_INPUT_FILES": ( + "OPENAI_INPUT_FILES", + { + "tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.", + "default": None, + }, + ), + }, + } + + DESCRIPTION = "Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes." + RETURN_TYPES = ("OPENAI_INPUT_FILES",) + FUNCTION = "prepare_files" + CATEGORY = "api node/text/OpenAI" + + def create_input_file_content(self, file_path: str) -> InputFileContent: + return InputFileContent( + file_data=text_filepath_to_data_uri(file_path), + filename=os.path.basename(file_path), + type="input_file", + ) + + def prepare_files( + self, file: str, OPENAI_INPUT_FILES: list[InputFileContent] = [] + ) -> tuple[list[InputFileContent]]: + """ + Loads and formats input files for OpenAI API. + """ + file_path = folder_paths.get_annotated_filepath(file) + input_file_content = self.create_input_file_content(file_path) + files = [input_file_content] + OPENAI_INPUT_FILES + return (files,) + + +class OpenAIChatConfig(ComfyNodeABC): + """Allows setting additional configuration for the OpenAI Chat Node.""" + + RETURN_TYPES = ("OPENAI_CHAT_CONFIG",) + FUNCTION = "configure" + DESCRIPTION = ( + "Allows specifying advanced configuration options for the OpenAI Chat Nodes." + ) + CATEGORY = "api node/text/OpenAI" + + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": { + "truncation": ( + IO.COMBO, + { + "options": ["auto", "disabled"], + "default": "auto", + "tooltip": "The truncation strategy to use for the model response. 
auto: If the context of this response and previous ones exceeds the model's context window size, the model will truncate the response to fit the context window by dropping input items in the middle of the conversation.disabled: If a model response will exceed the context window size for a model, the request will fail with a 400 error", + }, + ), + }, + "optional": { + "max_output_tokens": model_field_to_node_input( + IO.INT, + OpenAICreateResponse, + "max_output_tokens", + min=16, + default=4096, + max=16384, + tooltip="An upper bound for the number of tokens that can be generated for a response, including visible output tokens", + ), + "instructions": model_field_to_node_input( + IO.STRING, OpenAICreateResponse, "instructions", multiline=True + ), + }, + } + + def configure( + self, + truncation: bool, + instructions: Optional[str] = None, + max_output_tokens: Optional[int] = None, + ) -> tuple[CreateModelResponseProperties]: + """ + Configure advanced options for the OpenAI Chat Node. + + Note: + While `top_p` and `temperature` are listed as properties in the + spec, they are not supported for all models (e.g., o4-mini). + They are not exposed as inputs at all to avoid having to manually + remove depending on model choice. + """ + return ( + CreateModelResponseProperties( + instructions=instructions, + truncation=truncation, + max_output_tokens=max_output_tokens, + ), + ) + + NODE_CLASS_MAPPINGS = { "OpenAIDalle2": OpenAIDalle2, "OpenAIDalle3": OpenAIDalle3, "OpenAIGPTImage1": OpenAIGPTImage1, + "OpenAIChatNode": OpenAIChatNode, + "OpenAIInputFiles": OpenAIInputFiles, + "OpenAIChatConfig": OpenAIChatConfig, } -# A dictionary that contains the friendly/humanly readable titles for the nodes NODE_DISPLAY_NAME_MAPPINGS = { "OpenAIDalle2": "OpenAI DALL·E 2", "OpenAIDalle3": "OpenAI DALL·E 3", "OpenAIGPTImage1": "OpenAI GPT Image 1", + "OpenAIChatNode": "OpenAI Chat", + "OpenAIInputFiles": "OpenAI Chat Input Files", + "OpenAIChatConfig": "OpenAI Chat Advanced Options", } diff --git a/comfy_api_nodes/nodes_pika.py b/comfy_api_nodes/nodes_pika.py index 08ec9cf07..1cc708564 100644 --- a/comfy_api_nodes/nodes_pika.py +++ b/comfy_api_nodes/nodes_pika.py @@ -6,40 +6,42 @@ Pika API docs: https://pika-827374fb.mintlify.app/api-reference from __future__ import annotations import io -from typing import Optional, TypeVar import logging -import torch +from typing import Optional, TypeVar + import numpy as np +import torch + +from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeOptions +from comfy_api.input_impl import VideoFromFile +from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput +from comfy_api_nodes.apinode_utils import ( + download_url_to_video_output, + tensor_to_bytesio, +) from comfy_api_nodes.apis import ( - PikaBodyGenerate22T2vGenerate22T2vPost, - PikaGenerateResponse, - PikaBodyGenerate22I2vGenerate22I2vPost, - PikaVideoResponse, - PikaBodyGenerate22C2vGenerate22PikascenesPost, IngredientsMode, - PikaDurationEnum, - PikaResolutionEnum, - PikaBodyGeneratePikaffectsGeneratePikaffectsPost, - PikaBodyGeneratePikadditionsGeneratePikadditionsPost, - PikaBodyGeneratePikaswapsGeneratePikaswapsPost, + PikaBodyGenerate22C2vGenerate22PikascenesPost, + PikaBodyGenerate22I2vGenerate22I2vPost, PikaBodyGenerate22KeyframeGenerate22PikaframesPost, + PikaBodyGenerate22T2vGenerate22T2vPost, + PikaBodyGeneratePikadditionsGeneratePikadditionsPost, + PikaBodyGeneratePikaffectsGeneratePikaffectsPost, + PikaBodyGeneratePikaswapsGeneratePikaswapsPost, + 
PikaDurationEnum, Pikaffect, + PikaGenerateResponse, + PikaResolutionEnum, + PikaVideoResponse, ) from comfy_api_nodes.apis.client import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( - tensor_to_bytesio, - download_url_to_video_output, + HttpMethod, + PollingOperation, + SynchronousOperation, ) from comfy_api_nodes.mapper_utils import model_field_to_node_input -from comfy_api.input_impl.video_types import VideoInput, VideoContainer, VideoCodec -from comfy_api.input_impl import VideoFromFile -from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeOptions R = TypeVar("R") @@ -121,7 +123,10 @@ class PikaNodeBase(ComfyNodeABC): RETURN_TYPES = ("VIDEO",) def poll_for_task_status( - self, task_id: str, auth_kwargs: Optional[dict[str,str]] = None + self, + task_id: str, + auth_kwargs: Optional[dict[str, str]] = None, + node_id: Optional[str] = None, ) -> PikaGenerateResponse: polling_operation = PollingOperation( poll_endpoint=ApiEndpoint( @@ -141,13 +146,19 @@ class PikaNodeBase(ComfyNodeABC): response.progress if hasattr(response, "progress") else None ), auth_kwargs=auth_kwargs, + result_url_extractor=lambda response: ( + response.url if hasattr(response, "url") else None + ), + node_id=node_id, + estimated_duration=60 ) return polling_operation.execute() def execute_task( self, initial_operation: SynchronousOperation[R, PikaGenerateResponse], - auth_kwargs: Optional[dict[str,str]] = None, + auth_kwargs: Optional[dict[str, str]] = None, + node_id: Optional[str] = None, ) -> tuple[VideoFromFile]: """Executes the initial operation then polls for the task status until it is completed. @@ -195,6 +206,7 @@ class PikaImageToVideoV2_2(PikaNodeBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -208,7 +220,8 @@ class PikaImageToVideoV2_2(PikaNodeBase): seed: int, resolution: str, duration: int, - **kwargs + unique_id: str, + **kwargs, ) -> tuple[VideoFromFile]: # Convert image to BytesIO image_bytes_io = tensor_to_bytesio(image) @@ -238,7 +251,7 @@ class PikaImageToVideoV2_2(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs) + return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaTextToVideoNodeV2_2(PikaNodeBase): @@ -262,6 +275,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -275,6 +289,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase): resolution: str, duration: int, aspect_ratio: float, + unique_id: str, **kwargs, ) -> tuple[VideoFromFile]: initial_operation = SynchronousOperation( @@ -296,7 +311,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase): content_type="application/x-www-form-urlencoded", ) - return self.execute_task(initial_operation, auth_kwargs=kwargs) + return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaScenesV2_2(PikaNodeBase): @@ -340,6 +355,7 @@ class PikaScenesV2_2(PikaNodeBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -354,6 +370,7 @@ class PikaScenesV2_2(PikaNodeBase): duration: int, ingredients_mode: str, aspect_ratio: float, + unique_id: str, image_ingredient_1: Optional[torch.Tensor] = None, image_ingredient_2: Optional[torch.Tensor] = None, image_ingredient_3: 
Optional[torch.Tensor] = None, @@ -403,7 +420,7 @@ class PikaScenesV2_2(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs) + return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikAdditionsNode(PikaNodeBase): @@ -439,10 +456,11 @@ class PikAdditionsNode(PikaNodeBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } - DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you’d like to add to create a seamlessly integrated result." + DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result." def api_call( self, @@ -451,6 +469,7 @@ class PikAdditionsNode(PikaNodeBase): prompt_text: str, negative_prompt: str, seed: int, + unique_id: str, **kwargs, ) -> tuple[VideoFromFile]: # Convert video to BytesIO @@ -487,7 +506,7 @@ class PikAdditionsNode(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs) + return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaSwapsNode(PikaNodeBase): @@ -532,6 +551,7 @@ class PikaSwapsNode(PikaNodeBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -546,6 +566,7 @@ class PikaSwapsNode(PikaNodeBase): prompt_text: str, negative_prompt: str, seed: int, + unique_id: str, **kwargs, ) -> tuple[VideoFromFile]: # Convert video to BytesIO @@ -592,7 +613,7 @@ class PikaSwapsNode(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs) + return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaffectsNode(PikaNodeBase): @@ -637,6 +658,7 @@ class PikaffectsNode(PikaNodeBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -649,6 +671,7 @@ class PikaffectsNode(PikaNodeBase): prompt_text: str, negative_prompt: str, seed: int, + unique_id: str, **kwargs, ) -> tuple[VideoFromFile]: @@ -670,7 +693,7 @@ class PikaffectsNode(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs) + return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaStartEndFrameNode2_2(PikaNodeBase): @@ -689,6 +712,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -703,6 +727,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase): seed: int, resolution: str, duration: int, + unique_id: str, **kwargs, ) -> tuple[VideoFromFile]: @@ -733,7 +758,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs) + return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) NODE_CLASS_MAPPINGS = { diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index 0c29e77c2..ef4a9a802 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -1,5 +1,5 @@ from inspect import cleandoc - +from typing import Optional from comfy_api_nodes.apis.pixverse_api import ( PixverseTextVideoRequest, PixverseImageVideoRequest, @@ -34,11 +34,22 @@ import requests from io import BytesIO 
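Pika's poller above has two progress signals: a `progress_extractor` that reads `response.progress` when the API reports one, and a flat `estimated_duration=60` fallback. One plausible way a UI could combine them for display (the blending policy is an assumption; the diff does not show `PollingOperation`'s internals):

```python
import time
from typing import Optional

def display_progress(
    api_progress: Optional[float],
    started_at: float,
    estimated_duration: float = 60.0,
) -> float:
    """Prefer the API-reported progress; otherwise estimate from elapsed time."""
    if api_progress is not None:
        return max(0.0, min(100.0, api_progress))
    elapsed = time.time() - started_at
    # Cap below 100% until the task is actually confirmed complete.
    return min(95.0, 100.0 * elapsed / estimated_duration)
```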
+AVERAGE_DURATION_T2V = 32 +AVERAGE_DURATION_I2V = 30 +AVERAGE_DURATION_T2T = 52 + + +def get_video_url_from_response( + response: PixverseGenerationStatusResponse, +) -> Optional[str]: + if response.Resp is None or response.Resp.url is None: + return None + return str(response.Resp.url) + + def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): # first, upload image to Pixverse and get image id to use in actual generation call - files = { - "image": tensor_to_bytesio(image) - } + files = {"image": tensor_to_bytesio(image)} operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/pixverse/image/upload", @@ -54,7 +65,9 @@ def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): response_upload: PixverseImageUploadResponse = operation.execute() if response_upload.Resp is None: - raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'") + raise Exception( + f"PixVerse image upload request failed: '{response_upload.ErrMsg}'" + ) return response_upload.Resp.img_id @@ -73,7 +86,7 @@ class PixverseTemplateNode: def INPUT_TYPES(s): return { "required": { - "template": (list(pixverse_templates.keys()), ), + "template": (list(pixverse_templates.keys()),), } } @@ -87,7 +100,7 @@ class PixverseTemplateNode: class PixverseTextToVideoNode(ComfyNodeABC): """ - Generates videos synchronously based on prompt and output_size. + Generates videos based on prompt and output_size. """ RETURN_TYPES = (IO.VIDEO,) @@ -108,9 +121,7 @@ class PixverseTextToVideoNode(ComfyNodeABC): "tooltip": "Prompt for the video generation", }, ), - "aspect_ratio": ( - [ratio.value for ratio in PixverseAspectRatio], - ), + "aspect_ratio": ([ratio.value for ratio in PixverseAspectRatio],), "quality": ( [resolution.value for resolution in PixverseQuality], { @@ -143,12 +154,13 @@ class PixverseTextToVideoNode(ComfyNodeABC): PixverseIO.TEMPLATE, { "tooltip": "An optional template to influence style of generation, created by the PixVerse Template node." - } - ) + }, + ), }, "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -160,8 +172,9 @@ class PixverseTextToVideoNode(ComfyNodeABC): duration_seconds: int, motion_mode: str, seed, - negative_prompt: str=None, - pixverse_template: int=None, + negative_prompt: str = None, + pixverse_template: int = None, + unique_id: Optional[str] = None, **kwargs, ): validate_string(prompt, strip_whitespace=False) @@ -205,19 +218,27 @@ class PixverseTextToVideoNode(ComfyNodeABC): response_model=PixverseGenerationStatusResponse, ), completed_statuses=[PixverseStatus.successful], - failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted], + failed_statuses=[ + PixverseStatus.contents_moderation, + PixverseStatus.failed, + PixverseStatus.deleted, + ], status_extractor=lambda x: x.Resp.status, auth_kwargs=kwargs, + node_id=unique_id, + result_url_extractor=get_video_url_from_response, + estimated_duration=AVERAGE_DURATION_T2V, ) response_poll = operation.execute() vid_response = requests.get(response_poll.Resp.url) + return (VideoFromFile(BytesIO(vid_response.content)),) class PixverseImageToVideoNode(ComfyNodeABC): """ - Generates videos synchronously based on prompt and output_size. + Generates videos based on prompt and output_size. 
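All three PixVerse pollers treat `contents_moderation`, `failed`, and `deleted` as terminal failures and only `successful` as completion; anything else keeps polling. The classification on its own (enum members mirror the `PixverseStatus` values referenced above; the `generating` member is illustrative):

```python
from enum import Enum

class Status(str, Enum):
    successful = "successful"
    contents_moderation = "contents_moderation"
    failed = "failed"
    deleted = "deleted"
    generating = "generating"  # assumed non-terminal state

TERMINAL_FAILURES = {Status.contents_moderation, Status.failed, Status.deleted}

def classify(status: Status) -> str:
    if status is Status.successful:
        return "done"
    return "failed" if status in TERMINAL_FAILURES else "pending"

assert classify(Status.deleted) == "failed"
assert classify(Status.generating) == "pending"
```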
""" RETURN_TYPES = (IO.VIDEO,) @@ -230,9 +251,7 @@ class PixverseImageToVideoNode(ComfyNodeABC): def INPUT_TYPES(s): return { "required": { - "image": ( - IO.IMAGE, - ), + "image": (IO.IMAGE,), "prompt": ( IO.STRING, { @@ -273,12 +292,13 @@ class PixverseImageToVideoNode(ComfyNodeABC): PixverseIO.TEMPLATE, { "tooltip": "An optional template to influence style of generation, created by the PixVerse Template node." - } - ) + }, + ), }, "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -290,8 +310,9 @@ class PixverseImageToVideoNode(ComfyNodeABC): duration_seconds: int, motion_mode: str, seed, - negative_prompt: str=None, - pixverse_template: int=None, + negative_prompt: str = None, + pixverse_template: int = None, + unique_id: Optional[str] = None, **kwargs, ): validate_string(prompt, strip_whitespace=False) @@ -337,9 +358,16 @@ class PixverseImageToVideoNode(ComfyNodeABC): response_model=PixverseGenerationStatusResponse, ), completed_statuses=[PixverseStatus.successful], - failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted], + failed_statuses=[ + PixverseStatus.contents_moderation, + PixverseStatus.failed, + PixverseStatus.deleted, + ], status_extractor=lambda x: x.Resp.status, auth_kwargs=kwargs, + node_id=unique_id, + result_url_extractor=get_video_url_from_response, + estimated_duration=AVERAGE_DURATION_I2V, ) response_poll = operation.execute() @@ -349,7 +377,7 @@ class PixverseImageToVideoNode(ComfyNodeABC): class PixverseTransitionVideoNode(ComfyNodeABC): """ - Generates videos synchronously based on prompt and output_size. + Generates videos based on prompt and output_size. """ RETURN_TYPES = (IO.VIDEO,) @@ -362,12 +390,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC): def INPUT_TYPES(s): return { "required": { - "first_frame": ( - IO.IMAGE, - ), - "last_frame": ( - IO.IMAGE, - ), + "first_frame": (IO.IMAGE,), + "last_frame": (IO.IMAGE,), "prompt": ( IO.STRING, { @@ -408,6 +432,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -420,7 +445,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC): duration_seconds: int, motion_mode: str, seed, - negative_prompt: str=None, + negative_prompt: str = None, + unique_id: Optional[str] = None, **kwargs, ): validate_string(prompt, strip_whitespace=False) @@ -467,9 +493,16 @@ class PixverseTransitionVideoNode(ComfyNodeABC): response_model=PixverseGenerationStatusResponse, ), completed_statuses=[PixverseStatus.successful], - failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted], + failed_statuses=[ + PixverseStatus.contents_moderation, + PixverseStatus.failed, + PixverseStatus.deleted, + ], status_extractor=lambda x: x.Resp.status, auth_kwargs=kwargs, + node_id=unique_id, + result_url_extractor=get_video_url_from_response, + estimated_duration=AVERAGE_DURATION_T2V, ) response_poll = operation.execute() diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index 767d93e3c..e369c4b7e 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -1,5 +1,6 @@ from __future__ import annotations from inspect import cleandoc +from typing import Optional from comfy.utils import ProgressBar from comfy_extras.nodes_images import SVG # Added from comfy.comfy_types.node_typing import IO @@ -29,6 +30,8 @@ from 
comfy_api_nodes.apinode_utils import ( resize_mask_to_image, validate_string, ) +from server import PromptServer + import torch from io import BytesIO from PIL import UnidentifiedImageError @@ -388,6 +391,7 @@ class RecraftTextToImageNode: "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -400,6 +404,7 @@ class RecraftTextToImageNode: recraft_style: RecraftStyle = None, negative_prompt: str = None, recraft_controls: RecraftControls = None, + unique_id: Optional[str] = None, **kwargs, ): validate_string(prompt, strip_whitespace=False, max_length=1000) @@ -436,8 +441,15 @@ class RecraftTextToImageNode: ) response: RecraftImageGenerationResponse = operation.execute() images = [] + urls = [] for data in response.data: with handle_recraft_image_output(): + if unique_id and data.url: + urls.append(data.url) + urls_string = '\n'.join(urls) + PromptServer.instance.send_progress_text( + f"Result URL: {urls_string}", unique_id + ) image = bytesio_to_image_tensor( download_url_to_bytesio(data.url, timeout=1024) ) @@ -763,6 +775,7 @@ class RecraftTextToVectorNode: "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -775,6 +788,7 @@ class RecraftTextToVectorNode: seed, negative_prompt: str = None, recraft_controls: RecraftControls = None, + unique_id: Optional[str] = None, **kwargs, ): validate_string(prompt, strip_whitespace=False, max_length=1000) @@ -809,7 +823,14 @@ class RecraftTextToVectorNode: ) response: RecraftImageGenerationResponse = operation.execute() svg_data = [] + urls = [] for data in response.data: + if unique_id and data.url: + urls.append(data.url) + # Print result on each iteration in case of error + PromptServer.instance.send_progress_text( + f"Result URL: {' '.join(urls)}", unique_id + ) svg_data.append(download_url_to_bytesio(data.url, timeout=1024)) return (SVG(svg_data),) diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py new file mode 100644 index 000000000..67f90478c --- /dev/null +++ b/comfy_api_nodes/nodes_rodin.py @@ -0,0 +1,462 @@ +""" +ComfyUI X Rodin3D(Deemos) API Nodes + +Rodin API docs: https://developer.hyper3d.ai/ + +""" + +from __future__ import annotations +from inspect import cleandoc +from comfy.comfy_types.node_typing import IO +import folder_paths as comfy_paths +import requests +import os +import datetime +import shutil +import time +import io +import logging +import math +from PIL import Image +from comfy_api_nodes.apis.rodin_api import ( + Rodin3DGenerateRequest, + Rodin3DGenerateResponse, + Rodin3DCheckStatusRequest, + Rodin3DCheckStatusResponse, + Rodin3DDownloadRequest, + Rodin3DDownloadResponse, + JobStatus, +) +from comfy_api_nodes.apis.client import ( + ApiEndpoint, + HttpMethod, + SynchronousOperation, + PollingOperation, +) + + +COMMON_PARAMETERS = { + "Seed": ( + IO.INT, + { + "default":0, + "min":0, + "max":65535, + "display":"number" + } + ), + "Material_Type": ( + IO.COMBO, + { + "options": ["PBR", "Shaded"], + "default": "PBR" + } + ), + "Polygon_count": ( + IO.COMBO, + { + "options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"], + "default": "18K-Quad" + } + ) +} + +def create_task_error(response: Rodin3DGenerateResponse): + """Check if the response has error""" + return hasattr(response, "error") + + + +class Rodin3DAPI: + """ + Generate 3D Assets using Rodin API + """ + RETURN_TYPES = (IO.STRING,) + RETURN_NAMES = ("3D Model Path",) + CATEGORY = "api 
node/3d/Rodin" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "api_call" + API_NODE = True + + def tensor_to_filelike(self, tensor, max_pixels: int = 2048*2048): + """ + Converts a PyTorch tensor to a file-like object. + + Args: + - tensor (torch.Tensor): A tensor representing an image of shape (H, W, C) + where C is the number of channels (3 for RGB), H is height, and W is width. + + Returns: + - io.BytesIO: A file-like object containing the image data. + """ + array = tensor.cpu().numpy() + array = (array * 255).astype('uint8') + image = Image.fromarray(array, 'RGB') + + original_width, original_height = image.size + original_pixels = original_width * original_height + if original_pixels > max_pixels: + scale = math.sqrt(max_pixels / original_pixels) + new_width = int(original_width * scale) + new_height = int(original_height * scale) + else: + new_width, new_height = original_width, original_height + + if new_width != original_width or new_height != original_height: + image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) + + img_byte_arr = io.BytesIO() + image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression + img_byte_arr.seek(0) + return img_byte_arr + + def check_rodin_status(self, response: Rodin3DCheckStatusResponse) -> str: + has_failed = any(job.status == JobStatus.Failed for job in response.jobs) + all_done = all(job.status == JobStatus.Done for job in response.jobs) + status_list = [str(job.status) for job in response.jobs] + logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}") + if has_failed: + logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.") + raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.") + elif all_done: + return "DONE" + else: + return "Generating" + + def CreateGenerateTask(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs): + if images == None: + raise Exception("Rodin 3D generate requires at least 1 image.") + if len(images) >= 5: + raise Exception("Rodin 3D generate requires up to 5 image.") + + path = "/proxy/rodin/api/v2/rodin" + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=path, + method=HttpMethod.POST, + request_model=Rodin3DGenerateRequest, + response_model=Rodin3DGenerateResponse, + ), + request=Rodin3DGenerateRequest( + seed=seed, + tier=tier, + material=material, + quality=quality, + mesh_mode=mesh_mode + ), + files=[ + ( + "images", + open(image, "rb") if isinstance(image, str) else self.tensor_to_filelike(image) + ) + for image in images if image is not None + ], + content_type = "multipart/form-data", + auth_kwargs=kwargs, + ) + + response = operation.execute() + + if create_task_error(response): + error_message = f"Rodin3D Create 3D generate Task Failed. 
Message: {response.message}, error: {response.error}" + logging.error(error_message) + raise Exception(error_message) + + logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!") + subscription_key = response.jobs.subscription_key + task_uuid = response.uuid + logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}") + return task_uuid, subscription_key + + def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse: + + path = "/proxy/rodin/api/v2/status" + + poll_operation = PollingOperation( + poll_endpoint=ApiEndpoint( + path = path, + method=HttpMethod.POST, + request_model=Rodin3DCheckStatusRequest, + response_model=Rodin3DCheckStatusResponse, + ), + request=Rodin3DCheckStatusRequest( + subscription_key = subscription_key + ), + completed_statuses=["DONE"], + failed_statuses=["FAILED"], + status_extractor=self.check_rodin_status, + poll_interval=3.0, + auth_kwargs=kwargs, + ) + + logging.info("[ Rodin3D API - CheckStatus ] Generate Start!") + + return poll_operation.execute() + + + + def GetRodinDownloadList(self, uuid, **kwargs) -> Rodin3DDownloadResponse: + logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") + + path = "/proxy/rodin/api/v2/download" + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=path, + method=HttpMethod.POST, + request_model=Rodin3DDownloadRequest, + response_model=Rodin3DDownloadResponse, + ), + request=Rodin3DDownloadRequest( + task_uuid=uuid + ), + auth_kwargs=kwargs + ) + + return operation.execute() + + def GetQualityAndMode(self, PolyCount): + if PolyCount == "200K-Triangle": + mesh_mode = "Raw" + quality = "medium" + else: + mesh_mode = "Quad" + if PolyCount == "4K-Quad": + quality = "extra-low" + elif PolyCount == "8K-Quad": + quality = "low" + elif PolyCount == "18K-Quad": + quality = "medium" + elif PolyCount == "50K-Quad": + quality = "high" + else: + quality = "medium" + + return mesh_mode, quality + + def DownLoadFiles(self, Url_List): + Save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) + os.makedirs(Save_path, exist_ok=True) + model_file_path = None + for Item in Url_List.list: + url = Item.url + file_name = Item.name + file_path = os.path.join(Save_path, file_name) + if file_path.endswith(".glb"): + model_file_path = file_path + logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}") + max_retries = 5 + for attempt in range(max_retries): + try: + with requests.get(url, stream=True) as r: + r.raise_for_status() + with open(file_path, "wb") as f: + shutil.copyfileobj(r.raw, f) + break + except Exception as e: + logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}") + if attempt < max_retries - 1: + logging.info("Retrying...") + time.sleep(2) + else: + logging.info(f"[ Rodin3D API - download_files ] Failed to download {file_path} after {max_retries} attempts.") + + return model_file_path + + +class Rodin3D_Regular(Rodin3DAPI): + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "Images": + ( + IO.IMAGE, + { + "forceInput":True, + } + ) + }, + "optional": { + **COMMON_PARAMETERS + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + }, + } + + def api_call( + self, + Images, + Seed, + Material_Type, + Polygon_count, + **kwargs + ): + tier = "Regular" + num_images = Images.shape[0] + m_images = [] + for i in range(num_images): + m_images.append(Images[i]) + mesh_mode, quality = 
self.GetQualityAndMode(Polygon_count) + task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) + self.poll_for_task_status(subscription_key, **kwargs) + Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) + model = self.DownLoadFiles(Download_List) + + return (model,) + +class Rodin3D_Detail(Rodin3DAPI): + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "Images": + ( + IO.IMAGE, + { + "forceInput":True, + } + ) + }, + "optional": { + **COMMON_PARAMETERS + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + }, + } + + def api_call( + self, + Images, + Seed, + Material_Type, + Polygon_count, + **kwargs + ): + tier = "Detail" + num_images = Images.shape[0] + m_images = [] + for i in range(num_images): + m_images.append(Images[i]) + mesh_mode, quality = self.GetQualityAndMode(Polygon_count) + task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) + self.poll_for_task_status(subscription_key, **kwargs) + Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) + model = self.DownLoadFiles(Download_List) + + return (model,) + +class Rodin3D_Smooth(Rodin3DAPI): + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "Images": + ( + IO.IMAGE, + { + "forceInput":True, + } + ) + }, + "optional": { + **COMMON_PARAMETERS + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + }, + } + + def api_call( + self, + Images, + Seed, + Material_Type, + Polygon_count, + **kwargs + ): + tier = "Smooth" + num_images = Images.shape[0] + m_images = [] + for i in range(num_images): + m_images.append(Images[i]) + mesh_mode, quality = self.GetQualityAndMode(Polygon_count) + task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) + self.poll_for_task_status(subscription_key, **kwargs) + Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) + model = self.DownLoadFiles(Download_List) + + return (model,) + +class Rodin3D_Sketch(Rodin3DAPI): + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "Images": + ( + IO.IMAGE, + { + "forceInput":True, + } + ) + }, + "optional": { + "Seed": + ( + IO.INT, + { + "default":0, + "min":0, + "max":65535, + "display":"number" + } + ) + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + }, + } + + def api_call( + self, + Images, + Seed, + **kwargs + ): + tier = "Sketch" + num_images = Images.shape[0] + m_images = [] + for i in range(num_images): + m_images.append(Images[i]) + material_type = "PBR" + quality = "medium" + mesh_mode = "Quad" + task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) + self.poll_for_task_status(subscription_key, **kwargs) + Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) + model = self.DownLoadFiles(Download_List) + + return (model,) + +# A dictionary that contains all nodes you want to export with their names +# NOTE: names should be globally unique +NODE_CLASS_MAPPINGS = { + "Rodin3D_Regular": Rodin3D_Regular, + "Rodin3D_Detail": Rodin3D_Detail, + "Rodin3D_Smooth": Rodin3D_Smooth, + "Rodin3D_Sketch": 
Rodin3D_Sketch, +} + +# A dictionary that contains the friendly/humanly readable titles for the nodes +NODE_DISPLAY_NAME_MAPPINGS = { + "Rodin3D_Regular": "Rodin 3D Generate - Regular Generate", + "Rodin3D_Detail": "Rodin 3D Generate - Detail Generate", + "Rodin3D_Smooth": "Rodin 3D Generate - Smooth Generate", + "Rodin3D_Sketch": "Rodin 3D Generate - Sketch Generate", +} diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py new file mode 100644 index 000000000..af4b321f9 --- /dev/null +++ b/comfy_api_nodes/nodes_runway.py @@ -0,0 +1,635 @@ +"""Runway API Nodes + +API Docs: + - https://docs.dev.runwayml.com/api/#tag/Task-management/paths/~1v1~1tasks~1%7Bid%7D/delete + +User Guides: + - https://help.runwayml.com/hc/en-us/sections/30265301423635-Gen-3-Alpha + - https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video + - https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo + - https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3 + +""" + +from typing import Union, Optional, Any +from enum import Enum + +import torch + +from comfy_api_nodes.apis import ( + RunwayImageToVideoRequest, + RunwayImageToVideoResponse, + RunwayTaskStatusResponse as TaskStatusResponse, + RunwayTaskStatusEnum as TaskStatus, + RunwayModelEnum as Model, + RunwayDurationEnum as Duration, + RunwayAspectRatioEnum as AspectRatio, + RunwayPromptImageObject, + RunwayPromptImageDetailedObject, + RunwayTextToImageRequest, + RunwayTextToImageResponse, + Model4, + ReferenceImage, + RunwayTextToImageAspectRatioEnum, +) +from comfy_api_nodes.apis.client import ( + ApiEndpoint, + HttpMethod, + SynchronousOperation, + PollingOperation, + EmptyRequest, +) +from comfy_api_nodes.apinode_utils import ( + upload_images_to_comfyapi, + download_url_to_video_output, + image_tensor_pair_to_batch, + validate_string, + download_url_to_image_tensor, +) +from comfy_api_nodes.mapper_utils import model_field_to_node_input +from comfy_api.input_impl import VideoFromFile +from comfy.comfy_types.node_typing import IO, ComfyNodeABC + +PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video" +PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image" +PATH_GET_TASK_STATUS = "/proxy/runway/tasks" + +AVERAGE_DURATION_I2V_SECONDS = 64 +AVERAGE_DURATION_FLF_SECONDS = 256 +AVERAGE_DURATION_T2I_SECONDS = 41 + + +class RunwayApiError(Exception): + """Base exception for Runway API errors.""" + + pass + + +class RunwayGen4TurboAspectRatio(str, Enum): + """Aspect ratios supported for Image to Video API when using gen4_turbo model.""" + + field_1280_720 = "1280:720" + field_720_1280 = "720:1280" + field_1104_832 = "1104:832" + field_832_1104 = "832:1104" + field_960_960 = "960:960" + field_1584_672 = "1584:672" + + +class RunwayGen3aAspectRatio(str, Enum): + """Aspect ratios supported for Image to Video API when using gen3a_turbo model.""" + + field_768_1280 = "768:1280" + field_1280_768 = "1280:768" + + +def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: + """Returns the video URL from the task status response if it exists.""" + if response.output and len(response.output) > 0: + return response.output[0] + return None + + +# TODO: replace with updated image validation utils (upstream) +def validate_input_image(image: torch.Tensor) -> bool: + """ + Validate the input image is within the size limits for the Runway API. 
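`validate_input_image` below only has to bound the spatial dimensions, because ComfyUI image tensors are batched NHWC: `shape[1]` is height and `shape[2]` is width. As a standalone sketch of the same check:

```python
import torch

def validate_input_image(image: torch.Tensor) -> bool:
    """image: ComfyUI-style batch tensor of shape (B, H, W, C)."""
    return image.shape[2] < 8000 and image.shape[1] < 8000

assert validate_input_image(torch.zeros(1, 720, 1280, 3))       # 1280x720 frame: OK
assert not validate_input_image(torch.zeros(1, 9000, 1280, 3))  # too tall: rejected
```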
+ See: https://docs.dev.runwayml.com/assets/inputs/#common-error-reasons + """ + return image.shape[2] < 8000 and image.shape[1] < 8000 + + +def poll_until_finished( + auth_kwargs: dict[str, str], + api_endpoint: ApiEndpoint[Any, TaskStatusResponse], + estimated_duration: Optional[int] = None, + node_id: Optional[str] = None, +) -> TaskStatusResponse: + """Polls the Runway API endpoint until the task reaches a terminal state, then returns the response.""" + return PollingOperation( + poll_endpoint=api_endpoint, + completed_statuses=[ + TaskStatus.SUCCEEDED.value, + ], + failed_statuses=[ + TaskStatus.FAILED.value, + TaskStatus.CANCELLED.value, + ], + status_extractor=lambda response: (response.status.value), + auth_kwargs=auth_kwargs, + result_url_extractor=get_video_url_from_task_status, + estimated_duration=estimated_duration, + node_id=node_id, + progress_extractor=extract_progress_from_task_status, + ).execute() + + +def extract_progress_from_task_status( + response: TaskStatusResponse, +) -> Union[float, None]: + if hasattr(response, "progress") and response.progress is not None: + return response.progress * 100 + return None + + +def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: + """Returns the image URL from the task status response if it exists.""" + if response.output and len(response.output) > 0: + return response.output[0] + return None + + +class RunwayVideoGenNode(ComfyNodeABC): + """Runway Video Node Base.""" + + RETURN_TYPES = ("VIDEO",) + FUNCTION = "api_call" + CATEGORY = "api node/video/Runway" + API_NODE = True + + def validate_task_created(self, response: RunwayImageToVideoResponse) -> bool: + """ + Validate the task creation response from the Runway API matches + expected format. + """ + if not bool(response.id): + raise RunwayApiError("Invalid initial response from Runway API.") + return True + + def validate_response(self, response: RunwayImageToVideoResponse) -> bool: + """ + Validate the successful task status response from the Runway API + matches expected format. + """ + if not response.output or len(response.output) == 0: + raise RunwayApiError( + "Runway task succeeded but no video data found in response." 
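+                # response.output, when present, is a list of result URLs;
+                # get_video_url_from_task_status above reads output[0]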
+ ) + return True + + def get_response( + self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None + ) -> RunwayImageToVideoResponse: + """Poll the task status until it is finished then get the response.""" + return poll_until_finished( + auth_kwargs, + ApiEndpoint( + path=f"{PATH_GET_TASK_STATUS}/{task_id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=TaskStatusResponse, + ), + estimated_duration=AVERAGE_DURATION_FLF_SECONDS, + node_id=node_id, + ) + + def generate_video( + self, + request: RunwayImageToVideoRequest, + auth_kwargs: dict[str, str], + node_id: Optional[str] = None, + ) -> tuple[VideoFromFile]: + initial_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=PATH_IMAGE_TO_VIDEO, + method=HttpMethod.POST, + request_model=RunwayImageToVideoRequest, + response_model=RunwayImageToVideoResponse, + ), + request=request, + auth_kwargs=auth_kwargs, + ) + + initial_response = initial_operation.execute() + self.validate_task_created(initial_response) + task_id = initial_response.id + + final_response = self.get_response(task_id, auth_kwargs, node_id) + self.validate_response(final_response) + + video_url = get_video_url_from_task_status(final_response) + return (download_url_to_video_output(video_url),) + + +class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode): + """Runway Image to Video Node using Gen3a Turbo model.""" + + DESCRIPTION = "Generate a video from a single starting frame using Gen3a Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo." + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "prompt": model_field_to_node_input( + IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True + ), + "start_frame": ( + IO.IMAGE, + {"tooltip": "Start frame to be used for the video"}, + ), + "duration": model_field_to_node_input( + IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration + ), + "ratio": model_field_to_node_input( + IO.COMBO, + RunwayImageToVideoRequest, + "ratio", + enum_type=RunwayGen3aAspectRatio, + ), + "seed": model_field_to_node_input( + IO.INT, + RunwayImageToVideoRequest, + "seed", + control_after_generate=True, + ), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", + }, + } + + def api_call( + self, + prompt: str, + start_frame: torch.Tensor, + duration: str, + ratio: str, + seed: int, + unique_id: Optional[str] = None, + **kwargs, + ) -> tuple[VideoFromFile]: + # Validate inputs + validate_string(prompt, min_length=1) + validate_input_image(start_frame) + + # Upload image + download_urls = upload_images_to_comfyapi( + start_frame, + max_images=1, + mime_type="image/png", + auth_kwargs=kwargs, + ) + if len(download_urls) != 1: + raise RunwayApiError("Failed to upload one or more images to comfy api.") + + return self.generate_video( + RunwayImageToVideoRequest( + promptText=prompt, + seed=seed, + model=Model("gen3a_turbo"), + duration=Duration(duration), + ratio=AspectRatio(ratio), + promptImage=RunwayPromptImageObject( + root=[ + RunwayPromptImageDetailedObject( + uri=str(download_urls[0]), position="first" + ) + ] + ), + ), + auth_kwargs=kwargs, + node_id=unique_id, + ) + + +class RunwayImageToVideoNodeGen4(RunwayVideoGenNode): + """Runway Image to Video Node using Gen4 Turbo model.""" + + DESCRIPTION 
= "Generate a video from a single starting frame using Gen4 Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video." + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "prompt": model_field_to_node_input( + IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True + ), + "start_frame": ( + IO.IMAGE, + {"tooltip": "Start frame to be used for the video"}, + ), + "duration": model_field_to_node_input( + IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration + ), + "ratio": model_field_to_node_input( + IO.COMBO, + RunwayImageToVideoRequest, + "ratio", + enum_type=RunwayGen4TurboAspectRatio, + ), + "seed": model_field_to_node_input( + IO.INT, + RunwayImageToVideoRequest, + "seed", + control_after_generate=True, + ), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", + }, + } + + def api_call( + self, + prompt: str, + start_frame: torch.Tensor, + duration: str, + ratio: str, + seed: int, + unique_id: Optional[str] = None, + **kwargs, + ) -> tuple[VideoFromFile]: + # Validate inputs + validate_string(prompt, min_length=1) + validate_input_image(start_frame) + + # Upload image + download_urls = upload_images_to_comfyapi( + start_frame, + max_images=1, + mime_type="image/png", + auth_kwargs=kwargs, + ) + if len(download_urls) != 1: + raise RunwayApiError("Failed to upload one or more images to comfy api.") + + return self.generate_video( + RunwayImageToVideoRequest( + promptText=prompt, + seed=seed, + model=Model("gen4_turbo"), + duration=Duration(duration), + ratio=AspectRatio(ratio), + promptImage=RunwayPromptImageObject( + root=[ + RunwayPromptImageDetailedObject( + uri=str(download_urls[0]), position="first" + ) + ] + ), + ), + auth_kwargs=kwargs, + node_id=unique_id, + ) + + +class RunwayFirstLastFrameNode(RunwayVideoGenNode): + """Runway First-Last Frame Node.""" + + DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3." + + def get_response( + self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None + ) -> RunwayImageToVideoResponse: + return poll_until_finished( + auth_kwargs, + ApiEndpoint( + path=f"{PATH_GET_TASK_STATUS}/{task_id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=TaskStatusResponse, + ), + estimated_duration=AVERAGE_DURATION_FLF_SECONDS, + node_id=node_id, + ) + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "prompt": model_field_to_node_input( + IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True + ), + "start_frame": ( + IO.IMAGE, + {"tooltip": "Start frame to be used for the video"}, + ), + "end_frame": ( + IO.IMAGE, + { + "tooltip": "End frame to be used for the video. Supported for gen3a_turbo only." 
+ }, + ), + "duration": model_field_to_node_input( + IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration + ), + "ratio": model_field_to_node_input( + IO.COMBO, + RunwayImageToVideoRequest, + "ratio", + enum_type=RunwayGen3aAspectRatio, + ), + "seed": model_field_to_node_input( + IO.INT, + RunwayImageToVideoRequest, + "seed", + control_after_generate=True, + ), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "unique_id": "UNIQUE_ID", + "comfy_api_key": "API_KEY_COMFY_ORG", + }, + } + + def api_call( + self, + prompt: str, + start_frame: torch.Tensor, + end_frame: torch.Tensor, + duration: str, + ratio: str, + seed: int, + unique_id: Optional[str] = None, + **kwargs, + ) -> tuple[VideoFromFile]: + # Validate inputs + validate_string(prompt, min_length=1) + validate_input_image(start_frame) + validate_input_image(end_frame) + + # Upload images + stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame) + download_urls = upload_images_to_comfyapi( + stacked_input_images, + max_images=2, + mime_type="image/png", + auth_kwargs=kwargs, + ) + if len(download_urls) != 2: + raise RunwayApiError("Failed to upload one or more images to comfy api.") + + return self.generate_video( + RunwayImageToVideoRequest( + promptText=prompt, + seed=seed, + model=Model("gen3a_turbo"), + duration=Duration(duration), + ratio=AspectRatio(ratio), + promptImage=RunwayPromptImageObject( + root=[ + RunwayPromptImageDetailedObject( + uri=str(download_urls[0]), position="first" + ), + RunwayPromptImageDetailedObject( + uri=str(download_urls[1]), position="last" + ), + ] + ), + ), + auth_kwargs=kwargs, + node_id=unique_id, + ) + + +class RunwayTextToImageNode(ComfyNodeABC): + """Runway Text to Image Node.""" + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "api_call" + CATEGORY = "api node/image/Runway" + API_NODE = True + DESCRIPTION = "Generate an image from a text prompt using Runway's Gen 4 model. You can also include reference images to guide the generation." + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "prompt": model_field_to_node_input( + IO.STRING, RunwayTextToImageRequest, "promptText", multiline=True + ), + "ratio": model_field_to_node_input( + IO.COMBO, + RunwayTextToImageRequest, + "ratio", + enum_type=RunwayTextToImageAspectRatioEnum, + ), + }, + "optional": { + "reference_image": ( + IO.IMAGE, + {"tooltip": "Optional reference image to guide the generation"}, + ) + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", + }, + } + + def validate_task_created(self, response: RunwayTextToImageResponse) -> bool: + """ + Validate the task creation response from the Runway API matches + expected format. + """ + if not bool(response.id): + raise RunwayApiError("Invalid initial response from Runway API.") + return True + + def validate_response(self, response: TaskStatusResponse) -> bool: + """ + Validate the successful task status response from the Runway API + matches expected format. + """ + if not response.output or len(response.output) == 0: + raise RunwayApiError( + "Runway task succeeded but no image data found in response." 
+ ) + return True + + def get_response( + self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None + ) -> TaskStatusResponse: + """Poll the task status until it is finished then get the response.""" + return poll_until_finished( + auth_kwargs, + ApiEndpoint( + path=f"{PATH_GET_TASK_STATUS}/{task_id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=TaskStatusResponse, + ), + estimated_duration=AVERAGE_DURATION_T2I_SECONDS, + node_id=node_id, + ) + + def api_call( + self, + prompt: str, + ratio: str, + reference_image: Optional[torch.Tensor] = None, + unique_id: Optional[str] = None, + **kwargs, + ) -> tuple[torch.Tensor]: + # Validate inputs + validate_string(prompt, min_length=1) + + # Prepare reference images if provided + reference_images = None + if reference_image is not None: + validate_input_image(reference_image) + download_urls = upload_images_to_comfyapi( + reference_image, + max_images=1, + mime_type="image/png", + auth_kwargs=kwargs, + ) + if len(download_urls) != 1: + raise RunwayApiError("Failed to upload reference image to comfy api.") + + reference_images = [ReferenceImage(uri=str(download_urls[0]))] + + # Create request + request = RunwayTextToImageRequest( + promptText=prompt, + model=Model4.gen4_image, + ratio=ratio, + referenceImages=reference_images, + ) + + # Execute initial request + initial_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=PATH_TEXT_TO_IMAGE, + method=HttpMethod.POST, + request_model=RunwayTextToImageRequest, + response_model=RunwayTextToImageResponse, + ), + request=request, + auth_kwargs=kwargs, + ) + + initial_response = initial_operation.execute() + self.validate_task_created(initial_response) + task_id = initial_response.id + + # Poll for completion + final_response = self.get_response( + task_id, auth_kwargs=kwargs, node_id=unique_id + ) + self.validate_response(final_response) + + # Download and return image + image_url = get_image_url_from_task_status(final_response) + return (download_url_to_image_tensor(image_url),) + + +NODE_CLASS_MAPPINGS = { + "RunwayFirstLastFrameNode": RunwayFirstLastFrameNode, + "RunwayImageToVideoNodeGen3a": RunwayImageToVideoNodeGen3a, + "RunwayImageToVideoNodeGen4": RunwayImageToVideoNodeGen4, + "RunwayTextToImageNode": RunwayTextToImageNode, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "RunwayFirstLastFrameNode": "Runway First-Last-Frame to Video", + "RunwayImageToVideoNodeGen3a": "Runway Image to Video (Gen3a Turbo)", + "RunwayImageToVideoNodeGen4": "Runway Image to Video (Gen4 Turbo)", + "RunwayTextToImageNode": "Runway Text to Image", +} diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py new file mode 100644 index 000000000..65f3b21f5 --- /dev/null +++ b/comfy_api_nodes/nodes_tripo.py @@ -0,0 +1,574 @@ +import os +from folder_paths import get_output_directory +from comfy_api_nodes.mapper_utils import model_field_to_node_input +from comfy.comfy_types.node_typing import IO +from comfy_api_nodes.apis import ( + TripoOrientation, + TripoModelVersion, +) +from comfy_api_nodes.apis.tripo_api import ( + TripoTaskType, + TripoStyle, + TripoFileReference, + TripoFileEmptyReference, + TripoUrlReference, + TripoTaskResponse, + TripoTaskStatus, + TripoTextToModelRequest, + TripoImageToModelRequest, + TripoMultiviewToModelRequest, + TripoTextureModelRequest, + TripoRefineModelRequest, + TripoAnimateRigRequest, + TripoAnimateRetargetRequest, + TripoConvertModelRequest, +) + +from comfy_api_nodes.apis.client import ( + ApiEndpoint, + HttpMethod, 
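+    # SynchronousOperation performs the initial task-creation POST;
+    # PollingOperation (wrapped by poll_until_finished below) then polls the
+    # task endpoint until it reaches a terminal status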
+ SynchronousOperation, + PollingOperation, + EmptyRequest, +) +from comfy_api_nodes.apinode_utils import ( + upload_images_to_comfyapi, + download_url_to_bytesio, +) + + +def upload_image_to_tripo(image, **kwargs): + urls = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs) + return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg")) + +def get_model_url_from_response(response: TripoTaskResponse) -> str: + if response.data is not None: + for key in ["pbr_model", "model", "base_model"]: + if getattr(response.data.output, key, None) is not None: + return getattr(response.data.output, key) + raise RuntimeError(f"Failed to get model url from response: {response}") + + +def poll_until_finished( + kwargs: dict[str, str], + response: TripoTaskResponse, +) -> tuple[str, str]: + """Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response.""" + if response.code != 0: + raise RuntimeError(f"Failed to generate mesh: {response.error}") + task_id = response.data.task_id + response_poll = PollingOperation( + poll_endpoint=ApiEndpoint( + path=f"/proxy/tripo/v2/openapi/task/{task_id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=TripoTaskResponse, + ), + completed_statuses=[TripoTaskStatus.SUCCESS], + failed_statuses=[ + TripoTaskStatus.FAILED, + TripoTaskStatus.CANCELLED, + TripoTaskStatus.UNKNOWN, + TripoTaskStatus.BANNED, + TripoTaskStatus.EXPIRED, + ], + status_extractor=lambda x: x.data.status, + auth_kwargs=kwargs, + node_id=kwargs["unique_id"], + result_url_extractor=get_model_url_from_response, + progress_extractor=lambda x: x.data.progress, + ).execute() + if response_poll.data.status == TripoTaskStatus.SUCCESS: + url = get_model_url_from_response(response_poll) + bytesio = download_url_to_bytesio(url) + # Save the downloaded model file + model_file = f"tripo_model_{task_id}.glb" + with open(os.path.join(get_output_directory(), model_file), "wb") as f: + f.write(bytesio.getvalue()) + return model_file, task_id + raise RuntimeError(f"Failed to generate mesh: {response_poll}") + +class TripoTextToModelNode: + """ + Generates 3D models synchronously based on a text prompt using Tripo's API. 
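+
+    The node submits the task, polls it to a terminal status, downloads the
+    resulting .glb into the ComfyUI output directory, and returns
+    (model_file, task_id); the task id can feed the texture/refine/rig nodes
+    below.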
+ """ + AVERAGE_DURATION = 80 + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "prompt": ("STRING", {"multiline": True}), + }, + "optional": { + "negative_prompt": ("STRING", {"multiline": True}), + "model_version": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "model_version", enum_type=TripoModelVersion), + "style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"), + "texture": ("BOOLEAN", {"default": True}), + "pbr": ("BOOLEAN", {"default": True}), + "image_seed": ("INT", {"default": 42}), + "model_seed": ("INT", {"default": 42}), + "texture_seed": ("INT", {"default": 42}), + "texture_quality": (["standard", "detailed"], {"default": "standard"}), + "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}), + "quad": ("BOOLEAN", {"default": False}) + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", + }, + } + + RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) + RETURN_NAMES = ("model_file", "model task_id") + FUNCTION = "generate_mesh" + CATEGORY = "api node/3d/Tripo" + API_NODE = True + OUTPUT_NODE = True + + def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): + style_enum = None if style == "None" else style + if not prompt: + raise RuntimeError("Prompt is required") + response = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/tripo/v2/openapi/task", + method=HttpMethod.POST, + request_model=TripoTextToModelRequest, + response_model=TripoTaskResponse, + ), + request=TripoTextToModelRequest( + type=TripoTaskType.TEXT_TO_MODEL, + prompt=prompt, + negative_prompt=negative_prompt if negative_prompt else None, + model_version=model_version, + style=style_enum, + texture=texture, + pbr=pbr, + image_seed=image_seed, + model_seed=model_seed, + texture_seed=texture_seed, + texture_quality=texture_quality, + face_limit=face_limit, + auto_size=True, + quad=quad + ), + auth_kwargs=kwargs, + ).execute() + return poll_until_finished(kwargs, response) + +class TripoImageToModelNode: + """ + Generates 3D models synchronously based on a single image using Tripo's API. 
+ """ + AVERAGE_DURATION = 80 + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + }, + "optional": { + "model_version": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "model_version", enum_type=TripoModelVersion), + "style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"), + "texture": ("BOOLEAN", {"default": True}), + "pbr": ("BOOLEAN", {"default": True}), + "model_seed": ("INT", {"default": 42}), + "orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation), + "texture_seed": ("INT", {"default": 42}), + "texture_quality": (["standard", "detailed"], {"default": "standard"}), + "texture_alignment": (["original_image", "geometry"], {"default": "original_image"}), + "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}), + "quad": ("BOOLEAN", {"default": False}) + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", + }, + } + + RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) + RETURN_NAMES = ("model_file", "model task_id") + FUNCTION = "generate_mesh" + CATEGORY = "api node/3d/Tripo" + API_NODE = True + OUTPUT_NODE = True + + def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): + style_enum = None if style == "None" else style + if image is None: + raise RuntimeError("Image is required") + tripo_file = upload_image_to_tripo(image, **kwargs) + response = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/tripo/v2/openapi/task", + method=HttpMethod.POST, + request_model=TripoImageToModelRequest, + response_model=TripoTaskResponse, + ), + request=TripoImageToModelRequest( + type=TripoTaskType.IMAGE_TO_MODEL, + file=tripo_file, + model_version=model_version, + style=style_enum, + texture=texture, + pbr=pbr, + model_seed=model_seed, + orientation=orientation, + texture_alignment=texture_alignment, + texture_seed=texture_seed, + texture_quality=texture_quality, + face_limit=face_limit, + auto_size=True, + quad=quad + ), + auth_kwargs=kwargs, + ).execute() + return poll_until_finished(kwargs, response) + +class TripoMultiviewToModelNode: + """ + Generates 3D models synchronously based on up to four images (front, left, back, right) using Tripo's API. 
+ """ + AVERAGE_DURATION = 80 + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + }, + "optional": { + "image_left": ("IMAGE",), + "image_back": ("IMAGE",), + "image_right": ("IMAGE",), + "model_version": model_field_to_node_input(IO.COMBO, TripoMultiviewToModelRequest, "model_version", enum_type=TripoModelVersion), + "orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation), + "texture": ("BOOLEAN", {"default": True}), + "pbr": ("BOOLEAN", {"default": True}), + "model_seed": ("INT", {"default": 42}), + "texture_seed": ("INT", {"default": 42}), + "texture_quality": (["standard", "detailed"], {"default": "standard"}), + "texture_alignment": (["original_image", "geometry"], {"default": "original_image"}), + "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}), + "quad": ("BOOLEAN", {"default": False}) + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", + }, + } + + RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) + RETURN_NAMES = ("model_file", "model task_id") + FUNCTION = "generate_mesh" + CATEGORY = "api node/3d/Tripo" + API_NODE = True + OUTPUT_NODE = True + + def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs): + if image is None: + raise RuntimeError("front image for multiview is required") + images = [] + image_dict = { + "image": image, + "image_left": image_left, + "image_back": image_back, + "image_right": image_right + } + if image_left is None and image_back is None and image_right is None: + raise RuntimeError("At least one of left, back, or right image must be provided for multiview") + for image_name in ["image", "image_left", "image_back", "image_right"]: + image_ = image_dict[image_name] + if image_ is not None: + tripo_file = upload_image_to_tripo(image_, **kwargs) + images.append(tripo_file) + else: + images.append(TripoFileEmptyReference()) + response = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/tripo/v2/openapi/task", + method=HttpMethod.POST, + request_model=TripoMultiviewToModelRequest, + response_model=TripoTaskResponse, + ), + request=TripoMultiviewToModelRequest( + type=TripoTaskType.MULTIVIEW_TO_MODEL, + files=images, + model_version=model_version, + orientation=orientation, + texture=texture, + pbr=pbr, + model_seed=model_seed, + texture_seed=texture_seed, + texture_quality=texture_quality, + texture_alignment=texture_alignment, + face_limit=face_limit, + quad=quad, + ), + auth_kwargs=kwargs, + ).execute() + return poll_until_finished(kwargs, response) + +class TripoTextureNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model_task_id": ("MODEL_TASK_ID",), + }, + "optional": { + "texture": ("BOOLEAN", {"default": True}), + "pbr": ("BOOLEAN", {"default": True}), + "texture_seed": ("INT", {"default": 42}), + "texture_quality": (["standard", "detailed"], {"default": "standard"}), + "texture_alignment": (["original_image", "geometry"], {"default": "original_image"}), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", + }, + } + + RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) + RETURN_NAMES = ("model_file", "model task_id") + FUNCTION = "generate_mesh" + CATEGORY = 
"api node/3d/Tripo" + API_NODE = True + OUTPUT_NODE = True + AVERAGE_DURATION = 80 + + def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs): + response = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/tripo/v2/openapi/task", + method=HttpMethod.POST, + request_model=TripoTextureModelRequest, + response_model=TripoTaskResponse, + ), + request=TripoTextureModelRequest( + original_model_task_id=model_task_id, + texture=texture, + pbr=pbr, + texture_seed=texture_seed, + texture_quality=texture_quality, + texture_alignment=texture_alignment + ), + auth_kwargs=kwargs, + ).execute() + return poll_until_finished(kwargs, response) + + +class TripoRefineNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model_task_id": ("MODEL_TASK_ID", { + "tooltip": "Must be a v1.4 Tripo model" + }), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", + }, + } + + DESCRIPTION = "Refine a draft model created by v1.4 Tripo models only." + + RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) + RETURN_NAMES = ("model_file", "model task_id") + FUNCTION = "generate_mesh" + CATEGORY = "api node/3d/Tripo" + API_NODE = True + OUTPUT_NODE = True + AVERAGE_DURATION = 240 + + def generate_mesh(self, model_task_id, **kwargs): + response = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/tripo/v2/openapi/task", + method=HttpMethod.POST, + request_model=TripoRefineModelRequest, + response_model=TripoTaskResponse, + ), + request=TripoRefineModelRequest( + draft_model_task_id=model_task_id + ), + auth_kwargs=kwargs, + ).execute() + return poll_until_finished(kwargs, response) + + +class TripoRigNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "original_model_task_id": ("MODEL_TASK_ID",), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", + }, + } + + RETURN_TYPES = ("STRING", "RIG_TASK_ID") + RETURN_NAMES = ("model_file", "rig task_id") + FUNCTION = "generate_mesh" + CATEGORY = "api node/3d/Tripo" + API_NODE = True + OUTPUT_NODE = True + AVERAGE_DURATION = 180 + + def generate_mesh(self, original_model_task_id, **kwargs): + response = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/tripo/v2/openapi/task", + method=HttpMethod.POST, + request_model=TripoAnimateRigRequest, + response_model=TripoTaskResponse, + ), + request=TripoAnimateRigRequest( + original_model_task_id=original_model_task_id, + out_format="glb", + spec="tripo" + ), + auth_kwargs=kwargs, + ).execute() + return poll_until_finished(kwargs, response) + +class TripoRetargetNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "original_model_task_id": ("RIG_TASK_ID",), + "animation": ([ + "preset:idle", + "preset:walk", + "preset:climb", + "preset:jump", + "preset:slash", + "preset:shoot", + "preset:hurt", + "preset:fall", + "preset:turn", + ],), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", + }, + } + + RETURN_TYPES = ("STRING", "RETARGET_TASK_ID") + RETURN_NAMES = ("model_file", "retarget task_id") + FUNCTION = "generate_mesh" + CATEGORY = "api node/3d/Tripo" + API_NODE = True + OUTPUT_NODE = True + AVERAGE_DURATION = 30 + + def generate_mesh(self, animation, original_model_task_id, **kwargs): + response = SynchronousOperation( + endpoint=ApiEndpoint( + 
path="/proxy/tripo/v2/openapi/task", + method=HttpMethod.POST, + request_model=TripoAnimateRetargetRequest, + response_model=TripoTaskResponse, + ), + request=TripoAnimateRetargetRequest( + original_model_task_id=original_model_task_id, + animation=animation, + out_format="glb", + bake_animation=True + ), + auth_kwargs=kwargs, + ).execute() + return poll_until_finished(kwargs, response) + +class TripoConversionNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "original_model_task_id": ("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID",), + "format": (["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"],), + }, + "optional": { + "quad": ("BOOLEAN", {"default": False}), + "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}), + "texture_size": ("INT", {"min": 128, "max": 4096, "default": 4096}), + "texture_format": (["BMP", "DPX", "HDR", "JPEG", "OPEN_EXR", "PNG", "TARGA", "TIFF", "WEBP"], {"default": "JPEG"}) + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", + }, + } + + @classmethod + def VALIDATE_INPUTS(cls, input_types): + # The min and max of input1 and input2 are still validated because + # we didn't take `input1` or `input2` as arguments + if input_types["original_model_task_id"] not in ("MODEL_TASK_ID", "RIG_TASK_ID", "RETARGET_TASK_ID"): + return "original_model_task_id must be MODEL_TASK_ID, RIG_TASK_ID or RETARGET_TASK_ID type" + return True + + RETURN_TYPES = () + FUNCTION = "generate_mesh" + CATEGORY = "api node/3d/Tripo" + API_NODE = True + OUTPUT_NODE = True + AVERAGE_DURATION = 30 + + def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs): + if not original_model_task_id: + raise RuntimeError("original_model_task_id is required") + response = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/tripo/v2/openapi/task", + method=HttpMethod.POST, + request_model=TripoConvertModelRequest, + response_model=TripoTaskResponse, + ), + request=TripoConvertModelRequest( + original_model_task_id=original_model_task_id, + format=format, + quad=quad if quad else None, + face_limit=face_limit if face_limit != -1 else None, + texture_size=texture_size if texture_size != 4096 else None, + texture_format=texture_format if texture_format != "JPEG" else None + ), + auth_kwargs=kwargs, + ).execute() + return poll_until_finished(kwargs, response) + +NODE_CLASS_MAPPINGS = { + "TripoTextToModelNode": TripoTextToModelNode, + "TripoImageToModelNode": TripoImageToModelNode, + "TripoMultiviewToModelNode": TripoMultiviewToModelNode, + "TripoTextureNode": TripoTextureNode, + "TripoRefineNode": TripoRefineNode, + "TripoRigNode": TripoRigNode, + "TripoRetargetNode": TripoRetargetNode, + "TripoConversionNode": TripoConversionNode, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "TripoTextToModelNode": "Tripo: Text to Model", + "TripoImageToModelNode": "Tripo: Image to Model", + "TripoMultiviewToModelNode": "Tripo: Multiview to Model", + "TripoTextureNode": "Tripo: Texture model", + "TripoRefineNode": "Tripo: Refine Draft model", + "TripoRigNode": "Tripo: Rig model", + "TripoRetargetNode": "Tripo: Retarget rigged model", + "TripoConversionNode": "Tripo: Convert model", +} diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 2740179c8..df846d5dd 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -3,6 +3,7 @@ import logging import base64 import requests import torch +from typing import Optional from 
comfy.comfy_types.node_typing import IO, ComfyNodeABC from comfy_api.input_impl.video_types import VideoFromFile @@ -24,6 +25,8 @@ from comfy_api_nodes.apinode_utils import ( tensor_to_base64_string ) +AVERAGE_DURATION_VIDEO_GEN = 32 + def convert_image_to_base64(image: torch.Tensor): if image is None: return None @@ -31,6 +34,22 @@ def convert_image_to_base64(image: torch.Tensor): scaled_image = downscale_image_tensor(image, total_pixels=2048*2048) return tensor_to_base64_string(scaled_image) + +def get_video_url_from_response(poll_response: Veo2GenVidPollResponse) -> Optional[str]: + if ( + poll_response.response + and hasattr(poll_response.response, "videos") + and poll_response.response.videos + and len(poll_response.response.videos) > 0 + ): + video = poll_response.response.videos[0] + else: + return None + if hasattr(video, "gcsUri") and video.gcsUri: + return str(video.gcsUri) + return None + + class VeoVideoGenerationNode(ComfyNodeABC): """ Generates videos from text prompts using Google's Veo API. @@ -115,6 +134,7 @@ class VeoVideoGenerationNode(ComfyNodeABC): "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", }, } @@ -134,6 +154,7 @@ class VeoVideoGenerationNode(ComfyNodeABC): person_generation="ALLOW", seed=0, image=None, + unique_id: Optional[str] = None, **kwargs, ): # Prepare the instances for the request @@ -215,7 +236,10 @@ class VeoVideoGenerationNode(ComfyNodeABC): operationName=operation_name ), auth_kwargs=kwargs, - poll_interval=5.0 + poll_interval=5.0, + result_url_extractor=get_video_url_from_response, + node_id=unique_id, + estimated_duration=AVERAGE_DURATION_VIDEO_GEN, ) # Execute the polling operation diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py new file mode 100644 index 000000000..031b9fbd3 --- /dev/null +++ b/comfy_api_nodes/util/validation_utils.py @@ -0,0 +1,100 @@ +import logging +from typing import Optional + +import torch +from comfy_api.input.video_types import VideoInput + + +def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]: + if len(image.shape) == 4: + return image.shape[1], image.shape[2] + elif len(image.shape) == 3: + return image.shape[0], image.shape[1] + else: + raise ValueError("Invalid image tensor shape.") + + +def validate_image_dimensions( + image: torch.Tensor, + min_width: Optional[int] = None, + max_width: Optional[int] = None, + min_height: Optional[int] = None, + max_height: Optional[int] = None, +): + height, width = get_image_dimensions(image) + + if min_width is not None and width < min_width: + raise ValueError(f"Image width must be at least {min_width}px, got {width}px") + if max_width is not None and width > max_width: + raise ValueError(f"Image width must be at most {max_width}px, got {width}px") + if min_height is not None and height < min_height: + raise ValueError( + f"Image height must be at least {min_height}px, got {height}px" + ) + if max_height is not None and height > max_height: + raise ValueError(f"Image height must be at most {max_height}px, got {height}px") + + +def validate_image_aspect_ratio( + image: torch.Tensor, + min_aspect_ratio: Optional[float] = None, + max_aspect_ratio: Optional[float] = None, +): + width, height = get_image_dimensions(image) + aspect_ratio = width / height + + if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio: + 
raise ValueError( + f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}" + ) + if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio: + raise ValueError( + f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}" + ) + + +def validate_video_dimensions( + video: VideoInput, + min_width: Optional[int] = None, + max_width: Optional[int] = None, + min_height: Optional[int] = None, + max_height: Optional[int] = None, +): + try: + width, height = video.get_dimensions() + except Exception as e: + logging.error("Error getting dimensions of video: %s", e) + return + + if min_width is not None and width < min_width: + raise ValueError(f"Video width must be at least {min_width}px, got {width}px") + if max_width is not None and width > max_width: + raise ValueError(f"Video width must be at most {max_width}px, got {width}px") + if min_height is not None and height < min_height: + raise ValueError( + f"Video height must be at least {min_height}px, got {height}px" + ) + if max_height is not None and height > max_height: + raise ValueError(f"Video height must be at most {max_height}px, got {height}px") + + +def validate_video_duration( + video: VideoInput, + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, +): + try: + duration = video.get_duration() + except Exception as e: + logging.error("Error getting duration of video: %s", e) + return + + epsilon = 0.0001 + if min_duration is not None and min_duration - epsilon > duration: + raise ValueError( + f"Video duration must be at least {min_duration}s, got {duration}s" + ) + if max_duration is not None and duration > max_duration + epsilon: + raise ValueError( + f"Video duration must be at most {max_duration}s, got {duration}s" + ) diff --git a/comfy_config/config_parser.py b/comfy_config/config_parser.py new file mode 100644 index 000000000..a9cbd94dd --- /dev/null +++ b/comfy_config/config_parser.py @@ -0,0 +1,97 @@ +import os +from pathlib import Path +from typing import Optional + +from pydantic_settings import PydanticBaseSettingsSource, TomlConfigSettingsSource + +from comfy_config.types import ( + ComfyConfig, + ProjectConfig, + PyProjectConfig, + PyProjectSettings +) + +""" +Extract configuration from a custom node directory's pyproject.toml file or a Python file. + +This function reads and parses the pyproject.toml file in the specified directory +to extract project and ComfyUI-specific configuration information. If no +pyproject.toml file is found, it creates a minimal configuration using the +folder name as the project name. If a Python file is provided, it uses the +file name (without extension) as the project name. + +Args: + path (str): Path to the directory containing the pyproject.toml file, or + path to a .py file. If pyproject.toml doesn't exist in a directory, + the folder name will be used as the default project name. If a .py + file is provided, the filename (without .py extension) will be used + as the project name. + +Returns: + Optional[PyProjectConfig]: A PyProjectConfig object containing: + - project: Basic project information (name, version, dependencies, etc.) + - tool_comfy: ComfyUI-specific configuration (publisher_id, models, etc.) + Returns None if configuration extraction fails or if the provided file + is not a Python file. 
+ +Notes: + - If pyproject.toml is missing in a directory, creates a default config with folder name + - If a .py file is provided, creates a default config with filename (without extension) + - Returns None for non-Python files + +Example: + >>> from comfy_config import config_parser + >>> # For directory + >>> custom_node_dir = os.path.dirname(os.path.realpath(__file__)) + >>> project_config = config_parser.extract_node_configuration(custom_node_dir) + >>> print(project_config.project.name) # "my_custom_node" or name from pyproject.toml + >>> + >>> # For single-file Python node file + >>> py_file_path = os.path.realpath(__file__) # "/path/to/my_node.py" + >>> project_config = config_parser.extract_node_configuration(py_file_path) + >>> print(project_config.project.name) # "my_node" +""" +def extract_node_configuration(path) -> Optional[PyProjectConfig]: + if os.path.isfile(path): + file_path = Path(path) + + if file_path.suffix.lower() != '.py': + return None + + project_name = file_path.stem + project = ProjectConfig(name=project_name) + comfy = ComfyConfig() + return PyProjectConfig(project=project, tool_comfy=comfy) + + folder_name = os.path.basename(path) + toml_path = Path(path) / "pyproject.toml" + + if not toml_path.exists(): + project = ProjectConfig(name=folder_name) + comfy = ComfyConfig() + return PyProjectConfig(project=project, tool_comfy=comfy) + + raw_settings = load_pyproject_settings(toml_path) + + project_data = raw_settings.project + + tool_data = raw_settings.tool + comfy_data = tool_data.get("comfy", {}) if tool_data else {} + + return PyProjectConfig(project=project_data, tool_comfy=comfy_data) + + +def load_pyproject_settings(toml_path: Path) -> PyProjectSettings: + class PyProjectLoader(PyProjectSettings): + @classmethod + def settings_customise_sources( + cls, + settings_cls, + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ): + return (TomlConfigSettingsSource(settings_cls, toml_path),) + + return PyProjectLoader() diff --git a/comfy_config/types.py b/comfy_config/types.py new file mode 100644 index 000000000..5222cc59b --- /dev/null +++ b/comfy_config/types.py @@ -0,0 +1,93 @@ +from pydantic import BaseModel, Field, field_validator +from pydantic_settings import BaseSettings, SettingsConfigDict +from typing import List, Optional + +# IMPORTANT: The type definitions specified in pyproject.toml for custom nodes +# must remain synchronized with the corresponding files in the https://github.com/Comfy-Org/comfy-cli/blob/main/comfy_cli/registry/types.py. +# Any changes to one must be reflected in the other to maintain consistency. 
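+
+# Illustrative pyproject.toml these models are parsed from (assumed example
+# values; the capitalized keys follow the aliases declared below):
+#
+#   [project]
+#   name = "my_custom_node"
+#   version = "1.0.0"
+#   requires-python = ">= 3.9"
+#   license = { text = "MIT" }
+#
+#   [tool.comfy]
+#   PublisherId = "my-publisher"
+#   DisplayName = "My Custom Node"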
+ +class NodeVersion(BaseModel): + changelog: str + dependencies: List[str] + deprecated: bool + id: str + version: str + download_url: str + + +class Node(BaseModel): + id: str + name: str + description: str + author: Optional[str] = None + license: Optional[str] = None + icon: Optional[str] = None + repository: Optional[str] = None + tags: List[str] = Field(default_factory=list) + latest_version: Optional[NodeVersion] = None + + +class PublishNodeVersionResponse(BaseModel): + node_version: NodeVersion + signedUrl: str + + +class URLs(BaseModel): + homepage: str = Field(default="", alias="Homepage") + documentation: str = Field(default="", alias="Documentation") + repository: str = Field(default="", alias="Repository") + issues: str = Field(default="", alias="Issues") + + +class Model(BaseModel): + location: str + model_url: str + + +class ComfyConfig(BaseModel): + publisher_id: str = Field(default="", alias="PublisherId") + display_name: str = Field(default="", alias="DisplayName") + icon: str = Field(default="", alias="Icon") + models: List[Model] = Field(default_factory=list, alias="Models") + includes: List[str] = Field(default_factory=list) + web: Optional[str] = None + + +class License(BaseModel): + file: str = "" + text: str = "" + + +class ProjectConfig(BaseModel): + name: str = "" + description: str = "" + version: str = "1.0.0" + requires_python: str = Field(default=">= 3.9", alias="requires-python") + dependencies: List[str] = Field(default_factory=list) + license: License = Field(default_factory=License) + urls: URLs = Field(default_factory=URLs) + + @field_validator('license', mode='before') + @classmethod + def validate_license(cls, v): + if isinstance(v, str): + return License(text=v) + elif isinstance(v, dict): + return License(**v) + elif isinstance(v, License): + return v + else: + return License() + + +class PyProjectConfig(BaseModel): + project: ProjectConfig = Field(default_factory=ProjectConfig) + tool_comfy: ComfyConfig = Field(default_factory=ComfyConfig) + + +class PyProjectSettings(BaseSettings): + project: dict = Field(default_factory=dict) + + tool: dict = Field(default_factory=dict) + + model_config = SettingsConfigDict(extra='allow') diff --git a/comfy_extras/nodes_apg.py b/comfy_extras/nodes_apg.py new file mode 100644 index 000000000..25b21b1b8 --- /dev/null +++ b/comfy_extras/nodes_apg.py @@ -0,0 +1,76 @@ +import torch + +def project(v0, v1): + v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3]) + v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1 + v0_orthogonal = v0 - v0_parallel + return v0_parallel, v0_orthogonal + +class APG: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. 
Default CFG behavior at a setting of 1."}), + "norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}), + "momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}), + } + } + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + CATEGORY = "sampling/custom_sampling" + + def patch(self, model, eta, norm_threshold, momentum): + running_avg = 0 + prev_sigma = None + + def pre_cfg_function(args): + nonlocal running_avg, prev_sigma + + if len(args["conds_out"]) == 1: return args["conds_out"] + + cond = args["conds_out"][0] + uncond = args["conds_out"][1] + sigma = args["sigma"][0] + cond_scale = args["cond_scale"] + + if prev_sigma is not None and sigma > prev_sigma: + running_avg = 0 + prev_sigma = sigma + + guidance = cond - uncond + + if momentum != 0: + if not torch.is_tensor(running_avg): + running_avg = guidance + else: + running_avg = momentum * running_avg + guidance + guidance = running_avg + + if norm_threshold > 0: + guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True) + scale = torch.minimum( + torch.ones_like(guidance_norm), + norm_threshold / guidance_norm + ) + guidance = guidance * scale + + guidance_parallel, guidance_orthogonal = project(guidance, cond) + modified_guidance = guidance_orthogonal + eta * guidance_parallel + + modified_cond = (uncond + modified_guidance) + (cond - uncond) / cond_scale + + return [modified_cond, uncond] + args["conds_out"][2:] + + m = model.clone() + m.set_model_sampler_pre_cfg_function(pre_cfg_function) + return (m,) + +NODE_CLASS_MAPPINGS = { + "APG": APG, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "APG": "Adaptive Projected Guidance", +} diff --git a/comfy_extras/nodes_camera_trajectory.py b/comfy_extras/nodes_camera_trajectory.py new file mode 100644 index 000000000..5e0e39f91 --- /dev/null +++ b/comfy_extras/nodes_camera_trajectory.py @@ -0,0 +1,218 @@ +import nodes +import torch +import numpy as np +from einops import rearrange +import comfy.model_management + + + +MAX_RESOLUTION = nodes.MAX_RESOLUTION + +CAMERA_DICT = { + "base_T_norm": 1.5, + "base_angle": np.pi/3, + "Static": { "angle":[0., 0., 0.], "T":[0., 0., 0.]}, + "Pan Up": { "angle":[0., 0., 0.], "T":[0., -1., 0.]}, + "Pan Down": { "angle":[0., 0., 0.], "T":[0.,1.,0.]}, + "Pan Left": { "angle":[0., 0., 0.], "T":[-1.,0.,0.]}, + "Pan Right": { "angle":[0., 0., 0.], "T": [1.,0.,0.]}, + "Zoom In": { "angle":[0., 0., 0.], "T": [0.,0.,2.]}, + "Zoom Out": { "angle":[0., 0., 0.], "T": [0.,0.,-2.]}, + "Anti Clockwise (ACW)": { "angle": [0., 0., -1.], "T":[0., 0., 0.]}, + "ClockWise (CW)": { "angle": [0., 0., 1.], "T":[0., 0., 0.]}, +} + + +def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'): + + def get_relative_pose(cam_params): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params] + abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params] + cam_to_origin = 0 + target_cam_c2w = np.array([ + [1, 0, 0, 0], + [0, 1, 0, -cam_to_origin], + [0, 0, 1, 0], + [0, 0, 0, 1] + ]) + abs2rel = target_cam_c2w @ abs_w2cs[0] + ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]] + ret_poses = np.array(ret_poses, dtype=np.float32) + return ret_poses + + """Modified from 
https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + cam_params = [Camera(cam_param) for cam_param in cam_params] + + sample_wh_ratio = width / height + pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed + + if pose_wh_ratio > sample_wh_ratio: + resized_ori_w = height * pose_wh_ratio + for cam_param in cam_params: + cam_param.fx = resized_ori_w * cam_param.fx / width + else: + resized_ori_h = width / pose_wh_ratio + for cam_param in cam_params: + cam_param.fy = resized_ori_h * cam_param.fy / height + + intrinsic = np.asarray([[cam_param.fx * width, + cam_param.fy * height, + cam_param.cx * width, + cam_param.cy * height] + for cam_param in cam_params], dtype=np.float32) + + K = torch.as_tensor(intrinsic)[None] # [1, 1, 4] + c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere + c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4] + plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W + plucker_embedding = plucker_embedding[None] + plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0] + return plucker_embedding + +class Camera(object): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + def __init__(self, entry): + fx, fy, cx, cy = entry[1:5] + self.fx = fx + self.fy = fy + self.cx = cx + self.cy = cy + c2w_mat = np.array(entry[7:]).reshape(4, 4) + self.c2w_mat = c2w_mat + self.w2c_mat = np.linalg.inv(c2w_mat) + +def ray_condition(K, c2w, H, W, device): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + # c2w: B, V, 4, 4 + # K: B, V, 4 + + B = K.shape[0] + + j, i = torch.meshgrid( + torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype), + torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype), + indexing='ij' + ) + i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] + j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] + + fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1 + + zs = torch.ones_like(i) # [B, HxW] + xs = (i - cx) / fx * zs + ys = (j - cy) / fy * zs + zs = zs.expand_as(ys) + + directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3 + directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3 + + rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW + rays_o = c2w[..., :3, 3] # B, V, 3 + rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW + # c2w @ dirctions + rays_dxo = torch.cross(rays_o, rays_d) + plucker = torch.cat([rays_dxo, rays_d], dim=-1) + plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6 + # plucker = plucker.permute(0, 1, 4, 2, 3) + return plucker + +def get_camera_motion(angle, T, speed, n=81): + def compute_R_form_rad_angle(angles): + theta_x, theta_y, theta_z = angles + Rx = np.array([[1, 0, 0], + [0, np.cos(theta_x), -np.sin(theta_x)], + [0, np.sin(theta_x), np.cos(theta_x)]]) + + Ry = np.array([[np.cos(theta_y), 0, np.sin(theta_y)], + [0, 1, 0], + [-np.sin(theta_y), 0, np.cos(theta_y)]]) + + Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0], + [np.sin(theta_z), np.cos(theta_z), 0], + [0, 0, 1]]) + + R = np.dot(Rz, np.dot(Ry, Rx)) + return R + RT = [] + for i in range(n): + _angle = (i/n)*speed*(CAMERA_DICT["base_angle"])*angle + R = compute_R_form_rad_angle(_angle) + _T=(i/n)*speed*(CAMERA_DICT["base_T_norm"])*(T.reshape(3,1)) + _RT = np.concatenate([R,_T], axis=1) + RT.append(_RT) + RT = 
np.stack(RT) + return RT + +class WanCameraEmbedding: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "camera_pose":(["Static","Pan Up","Pan Down","Pan Left","Pan Right","Zoom In","Zoom Out","Anti Clockwise (ACW)", "ClockWise (CW)"],{"default":"Static"}), + "width": ("INT", {"default": 832, "min": 16, "max": MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": MAX_RESOLUTION, "step": 4}), + }, + "optional":{ + "speed":("FLOAT",{"default":1.0, "min": 0, "max": 10.0, "step": 0.1}), + "fx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}), + "fy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}), + "cx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}), + "cy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}), + } + + } + + RETURN_TYPES = ("WAN_CAMERA_EMBEDDING","INT","INT","INT") + RETURN_NAMES = ("camera_embedding","width","height","length") + FUNCTION = "run" + CATEGORY = "camera" + + def run(self, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5): + """ + Use Camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmannet al., 2021) + Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py + """ + motion_list = [camera_pose] + speed = speed + angle = np.array(CAMERA_DICT[motion_list[0]]["angle"]) + T = np.array(CAMERA_DICT[motion_list[0]]["T"]) + RT = get_camera_motion(angle, T, speed, length) + + trajs=[] + for cp in RT.tolist(): + traj=[fx,fy,cx,cy,0,0] + traj.extend(cp[0]) + traj.extend(cp[1]) + traj.extend(cp[2]) + traj.extend([0,0,0,1]) + trajs.append(traj) + + cam_params = np.array([[float(x) for x in pose] for pose in trajs]) + cam_params = np.concatenate([np.zeros_like(cam_params[:, :1]), cam_params], 1) + control_camera_video = process_pose_params(cam_params, width=width, height=height) + control_camera_video = control_camera_video.permute([3, 0, 1, 2]).unsqueeze(0).to(device=comfy.model_management.intermediate_device()) + + control_camera_video = torch.concat( + [ + torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2), + control_camera_video[:, :, 1:] + ], dim=2 + ).transpose(1, 2) + + # Reshape, transpose, and view into desired shape + b, f, c, h, w = control_camera_video.shape + control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) + control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) + + return (control_camera_video, width, height, length) + + +NODE_CLASS_MAPPINGS = { + "WanCameraEmbedding": WanCameraEmbedding, +} diff --git a/comfy_extras/nodes_cond.py b/comfy_extras/nodes_cond.py index 574262178..58c16f621 100644 --- a/comfy_extras/nodes_cond.py +++ b/comfy_extras/nodes_cond.py @@ -31,6 +31,7 @@ class T5TokenizerOptions: } } + CATEGORY = "_for_testing/conditioning" RETURN_TYPES = ("CLIP",) FUNCTION = "set_options" diff --git a/comfy_extras/nodes_cosmos.py b/comfy_extras/nodes_cosmos.py index bd35ddb06..4f4960551 100644 --- a/comfy_extras/nodes_cosmos.py +++ b/comfy_extras/nodes_cosmos.py @@ -2,6 +2,7 @@ import nodes import torch import comfy.model_management import comfy.utils +import comfy.latent_formats class EmptyCosmosLatentVideo: @@ -75,8 +76,53 @@ class CosmosImageToVideoLatent: out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) return (out_latent,) +class 
CosmosPredict2ImageToVideoLatent: + @classmethod + def INPUT_TYPES(s): + return {"required": {"vae": ("VAE", ), + "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 93, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"start_image": ("IMAGE", ), + "end_image": ("IMAGE", ), + }} + + + RETURN_TYPES = ("LATENT",) + FUNCTION = "encode" + + CATEGORY = "conditioning/inpaint" + + def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None): + latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is None and end_image is None: + out_latent = {} + out_latent["samples"] = latent + return (out_latent,) + + mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) + + if start_image is not None: + latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1) + latent[:, :, :latent_temp.shape[-3]] = latent_temp + mask[:, :, :latent_temp.shape[-3]] *= 0.0 + + if end_image is not None: + latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0) + latent[:, :, -latent_temp.shape[-3]:] = latent_temp + mask[:, :, -latent_temp.shape[-3]:] *= 0.0 + + out_latent = {} + latent_format = comfy.latent_formats.Wan21() + latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask) + out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1)) + out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) + return (out_latent,) NODE_CLASS_MAPPINGS = { "EmptyCosmosLatentVideo": EmptyCosmosLatentVideo, "CosmosImageToVideoLatent": CosmosImageToVideoLatent, + "CosmosPredict2ImageToVideoLatent": CosmosPredict2ImageToVideoLatent, } diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 504010ad0..d7278e7a7 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -77,7 +77,7 @@ class HunyuanImageToVideo: "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "guidance_type": (["v1 (concat)", "v2 (replace)"], ) + "guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], ) }, "optional": {"start_image": ("IMAGE", ), }} @@ -101,10 +101,12 @@ class HunyuanImageToVideo: if guidance_type == "v1 (concat)": cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask} - else: + elif guidance_type == "v2 (replace)": cond = {'guiding_frame_index': 0} latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image out_latent["noise_mask"] = mask + elif guidance_type == "custom": + cond = {"ref_latent": concat_latent_image} positive = node_helpers.conditioning_set_values(positive, cond) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 77c305619..b1e0d4666 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -13,8 +13,11 @@ import os import re from io import BytesIO from inspect import cleandoc +import torch +import comfy.utils -from comfy.comfy_types import 
FileLocator +from comfy.comfy_types import FileLocator, IO +from server import PromptServer MAX_RESOLUTION = nodes.MAX_RESOLUTION @@ -74,6 +77,24 @@ class ImageFromBatch: s = s_in[batch_index:batch_index + length].clone() return (s,) + +class ImageAddNoise: + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}), + "strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "repeat" + + CATEGORY = "image" + + def repeat(self, image, seed, strength): + generator = torch.manual_seed(seed) + s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0) + return (s,) + class SaveAnimatedWEBP: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -210,6 +231,186 @@ class SVG: all_svgs_list.extend(svg_item.data) return SVG(all_svgs_list) + +class ImageStitch: + """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image1": ("IMAGE",), + "direction": (["right", "down", "left", "up"], {"default": "right"}), + "match_image_size": ("BOOLEAN", {"default": True}), + "spacing_width": ( + "INT", + {"default": 0, "min": 0, "max": 1024, "step": 2}, + ), + "spacing_color": ( + ["white", "black", "red", "green", "blue"], + {"default": "white"}, + ), + }, + "optional": { + "image2": ("IMAGE",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "stitch" + CATEGORY = "image/transform" + DESCRIPTION = """ +Stitches image2 to image1 in the specified direction. +If image2 is not provided, returns image1 unchanged. +Optional spacing can be added between images. 
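+Batch size, channel count, and spatial size are matched automatically:
+extra frames repeat the last image, missing channels are padded, and when
+match_image_size is disabled the non-concatenated dimension is zero-padded.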
+""" + + def stitch( + self, + image1, + direction, + match_image_size, + spacing_width, + spacing_color, + image2=None, + ): + if image2 is None: + return (image1,) + + # Handle batch size differences + if image1.shape[0] != image2.shape[0]: + max_batch = max(image1.shape[0], image2.shape[0]) + if image1.shape[0] < max_batch: + image1 = torch.cat( + [image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)] + ) + if image2.shape[0] < max_batch: + image2 = torch.cat( + [image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)] + ) + + # Match image sizes if requested + if match_image_size: + h1, w1 = image1.shape[1:3] + h2, w2 = image2.shape[1:3] + aspect_ratio = w2 / h2 + + if direction in ["left", "right"]: + target_h, target_w = h1, int(h1 * aspect_ratio) + else: # up, down + target_w, target_h = w1, int(w1 / aspect_ratio) + + image2 = comfy.utils.common_upscale( + image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled" + ).movedim(1, -1) + + # When not matching sizes, pad to align non-concat dimensions + if not match_image_size: + h1, w1 = image1.shape[1:3] + h2, w2 = image2.shape[1:3] + + if direction in ["left", "right"]: + # For horizontal concat, pad heights to match + if h1 != h2: + target_h = max(h1, h2) + if h1 < target_h: + pad_h = target_h - h1 + pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2 + image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0) + if h2 < target_h: + pad_h = target_h - h2 + pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2 + image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0) + else: # up, down + # For vertical concat, pad widths to match + if w1 != w2: + target_w = max(w1, w2) + if w1 < target_w: + pad_w = target_w - w1 + pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2 + image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=0.0) + if w2 < target_w: + pad_w = target_w - w2 + pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2 + image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=0.0) + + # Ensure same number of channels + if image1.shape[-1] != image2.shape[-1]: + max_channels = max(image1.shape[-1], image2.shape[-1]) + if image1.shape[-1] < max_channels: + image1 = torch.cat( + [ + image1, + torch.ones( + *image1.shape[:-1], + max_channels - image1.shape[-1], + device=image1.device, + ), + ], + dim=-1, + ) + if image2.shape[-1] < max_channels: + image2 = torch.cat( + [ + image2, + torch.ones( + *image2.shape[:-1], + max_channels - image2.shape[-1], + device=image2.device, + ), + ], + dim=-1, + ) + + # Add spacing if specified + if spacing_width > 0: + spacing_width = spacing_width + (spacing_width % 2) # Ensure even + + color_map = { + "white": 1.0, + "black": 0.0, + "red": (1.0, 0.0, 0.0), + "green": (0.0, 1.0, 0.0), + "blue": (0.0, 0.0, 1.0), + } + color_val = color_map[spacing_color] + + if direction in ["left", "right"]: + spacing_shape = ( + image1.shape[0], + max(image1.shape[1], image2.shape[1]), + spacing_width, + image1.shape[-1], + ) + else: + spacing_shape = ( + image1.shape[0], + spacing_width, + max(image1.shape[2], image2.shape[2]), + image1.shape[-1], + ) + + spacing = torch.full(spacing_shape, 0.0, device=image1.device) + if isinstance(color_val, tuple): + for i, c in enumerate(color_val): + if i < spacing.shape[-1]: + spacing[..., i] = c + if spacing.shape[-1] == 4: # Add alpha + spacing[..., 3] = 1.0 + else: + 
spacing[..., : min(3, spacing.shape[-1])] = color_val + if spacing.shape[-1] == 4: + spacing[..., 3] = 1.0 + + # Concatenate images + images = [image2, image1] if direction in ["left", "up"] else [image1, image2] + if spacing_width > 0: + images.insert(1, spacing) + + concat_dim = 2 if direction in ["left", "right"] else 1 + return (torch.cat(images, dim=concat_dim),) + + class SaveSVGNode: """ Save SVG files on disk. @@ -291,11 +492,45 @@ class SaveSVGNode: counter += 1 return { "ui": { "images": results } } +class GetImageSize: + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": (IO.IMAGE,), + }, + "hidden": { + "unique_id": "UNIQUE_ID", + } + } + + RETURN_TYPES = (IO.INT, IO.INT, IO.INT) + RETURN_NAMES = ("width", "height", "batch_size") + FUNCTION = "get_size" + + CATEGORY = "image" + DESCRIPTION = """Returns width and height of the image, and passes it through unchanged.""" + + def get_size(self, image, unique_id=None) -> tuple[int, int]: + height = image.shape[1] + width = image.shape[2] + batch_size = image.shape[0] + + # Send progress text to display size on the node + if unique_id: + PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id) + + return width, height, batch_size + NODE_CLASS_MAPPINGS = { "ImageCrop": ImageCrop, "RepeatImageBatch": RepeatImageBatch, "ImageFromBatch": ImageFromBatch, + "ImageAddNoise": ImageAddNoise, "SaveAnimatedWEBP": SaveAnimatedWEBP, "SaveAnimatedPNG": SaveAnimatedPNG, "SaveSVGNode": SaveSVGNode, + "ImageStitch": ImageStitch, + "GetImageSize": GetImageSize, } diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index d5b4d9111..40d03e18a 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -16,7 +16,7 @@ class Load3D(): os.makedirs(input_dir, exist_ok=True) - files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.obj', '.mtl', '.fbx', '.stl'))] + files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.obj', '.fbx', '.stl'))] return {"required": { "model_file": (sorted(files), {"file_upload": True}), diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 71a652ffa..ae5d2c563 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -189,7 +189,7 @@ class ModelSamplingContinuousEDM: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps"],), + "sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps", "cosmos_rflow"],), "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), }} @@ -202,6 +202,7 @@ class ModelSamplingContinuousEDM: def patch(self, model, sampling, sigma_max, sigma_min): m = model.clone() + sampling_base = comfy.model_sampling.ModelSamplingContinuousEDM latent_format = None sigma_data = 1.0 if sampling == "eps": @@ -215,8 +216,11 @@ class ModelSamplingContinuousEDM: sampling_type = comfy.model_sampling.EDM sigma_data = 0.5 latent_format = comfy.latent_formats.SDXL_Playground_2_5() + elif sampling == "cosmos_rflow": + sampling_type = comfy.model_sampling.COSMOS_RFLOW + sampling_base = comfy.model_sampling.ModelSamplingCosmosRFlow - class 
ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type): + class ModelSamplingAdvanced(sampling_base, sampling_type): pass model_sampling = ModelSamplingAdvanced(model.model.model_config) diff --git a/comfy_extras/nodes_string.py b/comfy_extras/nodes_string.py index a852326e5..b1a8ceef0 100644 --- a/comfy_extras/nodes_string.py +++ b/comfy_extras/nodes_string.py @@ -8,7 +8,8 @@ class StringConcatenate(): return { "required": { "string_a": (IO.STRING, {"multiline": True}), - "string_b": (IO.STRING, {"multiline": True}) + "string_b": (IO.STRING, {"multiline": True}), + "delimiter": (IO.STRING, {"multiline": False, "default": ""}) } } @@ -16,8 +17,8 @@ class StringConcatenate(): FUNCTION = "execute" CATEGORY = "utils/string" - def execute(self, string_a, string_b, **kwargs): - return string_a + string_b, + def execute(self, string_a, string_b, delimiter, **kwargs): + return delimiter.join((string_a, string_b)), class StringSubstring(): @classmethod @@ -295,6 +296,41 @@ class RegexExtract(): return result, + +class RegexReplace(): + DESCRIPTION = "Find and replace text using regex patterns." + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "string": (IO.STRING, {"multiline": True}), + "regex_pattern": (IO.STRING, {"multiline": True}), + "replace": (IO.STRING, {"multiline": True}), + }, + "optional": { + "case_insensitive": (IO.BOOLEAN, {"default": True}), + "multiline": (IO.BOOLEAN, {"default": False}), + "dotall": (IO.BOOLEAN, {"default": False, "tooltip": "When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."}), + "count": (IO.INT, {"default": 0, "min": 0, "max": 100, "tooltip": "Maximum number of replacements to make. Set to 0 to replace all occurrences (default). 
Set to 1 to replace only the first match, 2 for the first two matches, etc."}), + } + } + + RETURN_TYPES = (IO.STRING,) + FUNCTION = "execute" + CATEGORY = "utils/string" + + def execute(self, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0, **kwargs): + flags = 0 + + if case_insensitive: + flags |= re.IGNORECASE + if multiline: + flags |= re.MULTILINE + if dotall: + flags |= re.DOTALL + result = re.sub(regex_pattern, replace, string, count=count, flags=flags) + return result, + NODE_CLASS_MAPPINGS = { "StringConcatenate": StringConcatenate, "StringSubstring": StringSubstring, @@ -305,7 +341,8 @@ NODE_CLASS_MAPPINGS = { "StringContains": StringContains, "StringCompare": StringCompare, "RegexMatch": RegexMatch, - "RegexExtract": RegexExtract + "RegexExtract": RegexExtract, + "RegexReplace": RegexReplace, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -318,5 +355,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "StringContains": "Contains", "StringCompare": "Compare", "RegexMatch": "Regex Match", - "RegexExtract": "Regex Extract" + "RegexExtract": "Regex Extract", + "RegexReplace": "Regex Replace", } diff --git a/comfy_extras/nodes_torch_compile.py b/comfy_extras/nodes_torch_compile.py index 1fe6f42c7..605536678 100644 --- a/comfy_extras/nodes_torch_compile.py +++ b/comfy_extras/nodes_torch_compile.py @@ -1,4 +1,5 @@ -import torch +from comfy_api.torch_helpers import set_torch_compile_wrapper + class TorchCompileModel: @classmethod @@ -14,7 +15,7 @@ class TorchCompileModel: def patch(self, model, backend): m = model.clone() - m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend)) + set_torch_compile_wrapper(model=m, backend=backend) return (m, ) NODE_CLASS_MAPPINGS = { diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py new file mode 100644 index 000000000..fbff01010 --- /dev/null +++ b/comfy_extras/nodes_train.py @@ -0,0 +1,709 @@ +import datetime +import json +import logging +import os + +import numpy as np +import safetensors +import torch +from PIL import Image, ImageDraw, ImageFont +from PIL.PngImagePlugin import PngInfo +import torch.utils.checkpoint +import tqdm + +import comfy.samplers +import comfy.sd +import comfy.utils +import comfy.model_management +import comfy_extras.nodes_custom_sampler +import folder_paths +import node_helpers +from comfy.cli_args import args +from comfy.comfy_types.node_typing import IO +from comfy.weight_adapter import adapters + + +class TrainSampler(comfy.samplers.Sampler): + + def __init__(self, loss_fn, optimizer, loss_callback=None): + self.loss_fn = loss_fn + self.optimizer = optimizer + self.loss_callback = loss_callback + + def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + self.optimizer.zero_grad() + noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False) + latent = model_wrap.inner_model.model_sampling.noise_scaling( + torch.zeros_like(sigmas), + torch.zeros_like(noise, requires_grad=True), + latent_image, + False + ) + + # Ensure model is in training mode and computing gradients + # x0 pred + denoised = model_wrap(noise, sigmas, **extra_args) + try: + loss = self.loss_fn(denoised, latent.clone()) + except RuntimeError as e: + if "does not require grad and does not have a grad_fn" in str(e): + logging.info("WARNING: This is likely due to the model is loaded in inference mode.") + loss.backward() + if self.loss_callback: + 
self.loss_callback(loss.item()) + + self.optimizer.step() + # torch.cuda.memory._dump_snapshot("trainn.pickle") + # torch.cuda.memory._record_memory_history(enabled=None) + return torch.zeros_like(latent_image) + + +class BiasDiff(torch.nn.Module): + def __init__(self, bias): + super().__init__() + self.bias = bias + + def __call__(self, b): + org_dtype = b.dtype + return (b.to(self.bias) + self.bias).to(org_dtype) + + def passive_memory_usage(self): + return self.bias.nelement() * self.bias.element_size() + + def move_to(self, device): + self.to(device=device) + return self.passive_memory_usage() + + +def load_and_process_images(image_files, input_dir, resize_method="None"): + """Utility function to load and process a list of images. + + Args: + image_files: List of image filenames + input_dir: Base directory containing the images + resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad") + + Returns: + torch.Tensor: Batch of processed images + """ + if not image_files: + raise ValueError("No valid images found in input") + + output_images = [] + w, h = None, None + + for file in image_files: + image_path = os.path.join(input_dir, file) + img = node_helpers.pillow(Image.open, image_path) + + if img.mode == "I": + img = img.point(lambda i: i * (1 / 255)) + img = img.convert("RGB") + + if w is None and h is None: + w, h = img.size[0], img.size[1] + + # Resize image to first image + if img.size[0] != w or img.size[1] != h: + if resize_method == "Stretch": + img = img.resize((w, h), Image.Resampling.LANCZOS) + elif resize_method == "Crop": + img = img.crop((0, 0, w, h)) + elif resize_method == "Pad": + img = img.resize((w, h), Image.Resampling.LANCZOS) + elif resize_method == "None": + raise ValueError( + "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images." + ) + + img_array = np.array(img).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array)[None,] + output_images.append(img_tensor) + + return torch.cat(output_images, dim=0) + + +class LoadImageSetNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "images": ( + [ + f + for f in os.listdir(folder_paths.get_input_directory()) + if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff")) + ], + {"image_upload": True, "allow_batch": True}, + ) + }, + "optional": { + "resize_method": ( + ["None", "Stretch", "Crop", "Pad"], + {"default": "None"}, + ), + }, + } + + INPUT_IS_LIST = True + RETURN_TYPES = ("IMAGE",) + FUNCTION = "load_images" + CATEGORY = "loaders" + EXPERIMENTAL = True + DESCRIPTION = "Loads a batch of images from a directory for training." 
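+    # Note: load_and_process_images() resizes every image in the set to the
+    # dimensions of the first image, using the selected resize_method.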
+ + @classmethod + def VALIDATE_INPUTS(s, images, resize_method): + filenames = images[0] if isinstance(images[0], list) else images + + for image in filenames: + if not folder_paths.exists_annotated_filepath(image): + return "Invalid image file: {}".format(image) + return True + + def load_images(self, input_files, resize_method): + input_dir = folder_paths.get_input_directory() + valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"] + image_files = [ + f + for f in input_files + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + output_tensor = load_and_process_images(image_files, input_dir, resize_method) + return (output_tensor,) + + +class LoadImageSetFromFolderNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}) + }, + "optional": { + "resize_method": ( + ["None", "Stretch", "Crop", "Pad"], + {"default": "None"}, + ), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "load_images" + CATEGORY = "loaders" + EXPERIMENTAL = True + DESCRIPTION = "Loads a batch of images from a directory for training." + + def load_images(self, folder, resize_method): + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] + image_files = [ + f + for f in os.listdir(sub_input_dir) + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method) + return (output_tensor,) + + +def draw_loss_graph(loss_map, steps): + width, height = 500, 300 + img = Image.new("RGB", (width, height), "white") + draw = ImageDraw.Draw(img) + + min_loss, max_loss = min(loss_map.values()), max(loss_map.values()) + scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_map.values()] + + prev_point = (0, height - int(scaled_loss[0] * height)) + for i, l in enumerate(scaled_loss[1:], start=1): + x = int(i / (steps - 1) * width) + y = height - int(l * height) + draw.line([prev_point, (x, y)], fill="blue", width=2) + prev_point = (x, y) + + return img + + +def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None): + if result is None: + result = [] + elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)): + result.append(model) + logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})") + return result + name = name or "root" + for next_name, child in model.named_children(): + find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}") + return result + + +def patch(m): + if not hasattr(m, "forward"): + return + org_forward = m.forward + def fwd(args, kwargs): + return org_forward(*args, **kwargs) + def checkpointing_fwd(*args, **kwargs): + return torch.utils.checkpoint.checkpoint( + fwd, args, kwargs, use_reentrant=False + ) + m.org_forward = org_forward + m.forward = checkpointing_fwd + + +def unpatch(m): + if hasattr(m, "org_forward"): + m.forward = m.org_forward + del m.org_forward + + +class TrainLoraNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}), + "latents": ( + "LATENT", + { + "tooltip": "The Latents to use for training, serve as dataset/input of the model." 
+ }, + ), + "positive": ( + IO.CONDITIONING, + {"tooltip": "The positive conditioning to use for training."}, + ), + "batch_size": ( + IO.INT, + { + "default": 1, + "min": 1, + "max": 10000, + "step": 1, + "tooltip": "The batch size to use for training.", + }, + ), + "steps": ( + IO.INT, + { + "default": 16, + "min": 1, + "max": 100000, + "tooltip": "The number of steps to train the LoRA for.", + }, + ), + "learning_rate": ( + IO.FLOAT, + { + "default": 0.0005, + "min": 0.0000001, + "max": 1.0, + "step": 0.000001, + "tooltip": "The learning rate to use for training.", + }, + ), + "rank": ( + IO.INT, + { + "default": 8, + "min": 1, + "max": 128, + "tooltip": "The rank of the LoRA layers.", + }, + ), + "optimizer": ( + ["AdamW", "Adam", "SGD", "RMSprop"], + { + "default": "AdamW", + "tooltip": "The optimizer to use for training.", + }, + ), + "loss_function": ( + ["MSE", "L1", "Huber", "SmoothL1"], + { + "default": "MSE", + "tooltip": "The loss function to use for training.", + }, + ), + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 0xFFFFFFFFFFFFFFFF, + "tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)", + }, + ), + "training_dtype": ( + ["bf16", "fp32"], + {"default": "bf16", "tooltip": "The dtype to use for training."}, + ), + "lora_dtype": ( + ["bf16", "fp32"], + {"default": "bf16", "tooltip": "The dtype to use for lora."}, + ), + "existing_lora": ( + folder_paths.get_filename_list("loras") + ["[None]"], + { + "default": "[None]", + "tooltip": "The existing LoRA to append to. Set to None for new LoRA.", + }, + ), + }, + } + + RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT) + RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps") + FUNCTION = "train" + CATEGORY = "training" + EXPERIMENTAL = True + + def train( + self, + model, + latents, + positive, + batch_size, + steps, + learning_rate, + rank, + optimizer, + loss_function, + seed, + training_dtype, + lora_dtype, + existing_lora, + ): + mp = model.clone() + dtype = node_helpers.string_to_torch_dtype(training_dtype) + lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) + mp.set_model_compute_dtype(dtype) + + latents = latents["samples"].to(dtype) + num_images = latents.shape[0] + + with torch.inference_mode(False): + lora_sd = {} + generator = torch.Generator() + generator.manual_seed(seed) + + # Load existing LoRA weights if provided + existing_weights = {} + existing_steps = 0 + if existing_lora != "[None]": + lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora) + # Extract steps from filename like "trained_lora_10_steps_20250225_203716" + existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1]) + if lora_path: + existing_weights = comfy.utils.load_torch_file(lora_path) + + all_weight_adapters = [] + for n, m in mp.model.named_modules(): + if hasattr(m, "weight_function"): + if m.weight is not None: + key = "{}.weight".format(n) + shape = m.weight.shape + if len(shape) >= 2: + alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) + dora_scale = existing_weights.get( + f"{key}.dora_scale", None + ) + for adapter_cls in adapters: + existing_adapter = adapter_cls.load( + n, existing_weights, alpha, dora_scale + ) + if existing_adapter is not None: + break + else: + # If no existing adapter found, use LoRA + # We will add algo option in the future + existing_adapter = None + adapter_cls = adapters[0] + + if existing_adapter is not None: + train_adapter = existing_adapter.to_train().to(lora_dtype) + 
else: + # Use LoRA with alpha=1.0 by default + train_adapter = adapter_cls.create_train( + m.weight, rank=rank, alpha=1.0 + ).to(lora_dtype) + for name, parameter in train_adapter.named_parameters(): + lora_sd[f"{n}.{name}"] = parameter + + mp.add_weight_wrapper(key, train_adapter) + all_weight_adapters.append(train_adapter) + else: + diff = torch.nn.Parameter( + torch.zeros( + m.weight.shape, dtype=lora_dtype, requires_grad=True + ) + ) + diff_module = BiasDiff(diff) + mp.add_weight_wrapper(key, BiasDiff(diff)) + all_weight_adapters.append(diff_module) + lora_sd["{}.diff".format(n)] = diff + if hasattr(m, "bias") and m.bias is not None: + key = "{}.bias".format(n) + bias = torch.nn.Parameter( + torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True) + ) + bias_module = BiasDiff(bias) + lora_sd["{}.diff_b".format(n)] = bias + mp.add_weight_wrapper(key, BiasDiff(bias)) + all_weight_adapters.append(bias_module) + + if optimizer == "Adam": + optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate) + elif optimizer == "AdamW": + optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate) + elif optimizer == "SGD": + optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate) + elif optimizer == "RMSprop": + optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate) + + # Setup loss function based on selection + if loss_function == "MSE": + criterion = torch.nn.MSELoss() + elif loss_function == "L1": + criterion = torch.nn.L1Loss() + elif loss_function == "Huber": + criterion = torch.nn.HuberLoss() + elif loss_function == "SmoothL1": + criterion = torch.nn.SmoothL1Loss() + + # setup models + for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model): + patch(m) + comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True) + + # Setup sampler and guider like in test script + loss_map = {"loss": []} + def loss_callback(loss): + loss_map["loss"].append(loss) + pbar.set_postfix({"loss": f"{loss:.4f}"}) + train_sampler = TrainSampler( + criterion, optimizer, loss_callback=loss_callback + ) + guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp) + guider.set_conds(positive) # Set conditioning from input + ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced() + + # yoland: this currently resize to the first image in the dataset + + # Training loop + torch.cuda.empty_cache() + try: + for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)): + # Generate random sigma + sigma = mp.model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + sigma = torch.tensor([sigma]) + + noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed) + + indices = torch.randperm(num_images)[:batch_size] + ss.sample( + noise, guider, train_sampler, sigma, {"samples": latents[indices].clone()} + ) + finally: + for m in mp.model.modules(): + unpatch(m) + del ss, train_sampler, optimizer + torch.cuda.empty_cache() + + for adapter in all_weight_adapters: + adapter.requires_grad_(False) + + for param in lora_sd: + lora_sd[param] = lora_sd[param].to(lora_dtype) + + return (mp, lora_sd, loss_map, steps + existing_steps) + + +class LoraModelLoader: + def __init__(self): + self.loaded_lora = None + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), + "lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to apply to the diffusion model."}), + 
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), + } + } + + RETURN_TYPES = ("MODEL",) + OUTPUT_TOOLTIPS = ("The modified diffusion model.",) + FUNCTION = "load_lora_model" + + CATEGORY = "loaders" + DESCRIPTION = "Load Trained LoRA weights from Train LoRA node." + EXPERIMENTAL = True + + def load_lora_model(self, model, lora, strength_model): + if strength_model == 0: + return (model, ) + + model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0) + return (model_lora, ) + + +class SaveLoRA: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "lora": ( + IO.LORA_MODEL, + { + "tooltip": "The LoRA model to save. Do not use the model with LoRA layers." + }, + ), + "prefix": ( + "STRING", + { + "default": "loras/ComfyUI_trained_lora", + "tooltip": "The prefix to use for the saved LoRA file.", + }, + ), + }, + "optional": { + "steps": ( + IO.INT, + { + "forceInput": True, + "tooltip": "Optional: The number of steps to LoRA has been trained for, used to name the saved file.", + }, + ), + }, + } + + RETURN_TYPES = () + FUNCTION = "save" + CATEGORY = "loaders" + EXPERIMENTAL = True + OUTPUT_NODE = True + + def save(self, lora, prefix, steps=None): + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(prefix, self.output_dir) + if steps is None: + output_checkpoint = f"{filename}_{counter:05}_.safetensors" + else: + output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors" + output_checkpoint = os.path.join(full_output_folder, output_checkpoint) + safetensors.torch.save_file(lora, output_checkpoint) + return {} + + +class LossGraphNode: + def __init__(self): + self.output_dir = folder_paths.get_temp_directory() + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "loss": (IO.LOSS_MAP, {"default": {}}), + "filename_prefix": (IO.STRING, {"default": "loss_graph"}), + }, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + + RETURN_TYPES = () + FUNCTION = "plot_loss" + OUTPUT_NODE = True + CATEGORY = "training" + EXPERIMENTAL = True + DESCRIPTION = "Plots the loss graph and saves it to the output directory." 
+ + def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None): + loss_values = loss["loss"] + width, height = 800, 480 + margin = 40 + + img = Image.new( + "RGB", (width + margin, height + margin), "white" + ) # Extend canvas + draw = ImageDraw.Draw(img) + + min_loss, max_loss = min(loss_values), max(loss_values) + scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_values] + + steps = len(loss_values) + + prev_point = (margin, height - int(scaled_loss[0] * height)) + for i, l in enumerate(scaled_loss[1:], start=1): + x = margin + int(i / steps * width) # Scale X properly + y = height - int(l * height) + draw.line([prev_point, (x, y)], fill="blue", width=2) + prev_point = (x, y) + + draw.line([(margin, 0), (margin, height)], fill="black", width=2) # Y-axis + draw.line( + [(margin, height), (width + margin, height)], fill="black", width=2 + ) # X-axis + + font = None + try: + font = ImageFont.truetype("arial.ttf", 12) + except IOError: + font = ImageFont.load_default() + + # Add axis labels + draw.text((5, height // 2), "Loss", font=font, fill="black") + draw.text((width // 2, height + 10), "Steps", font=font, fill="black") + + # Add min/max loss values + draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black") + draw.text( + (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black" + ) + + metadata = None + if not args.disable_metadata: + metadata = PngInfo() + if prompt is not None: + metadata.add_text("prompt", json.dumps(prompt)) + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata.add_text(x, json.dumps(extra_pnginfo[x])) + + date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + img.save( + os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"), + pnginfo=metadata, + ) + return { + "ui": { + "images": [ + { + "filename": f"{filename_prefix}_{date}.png", + "subfolder": "", + "type": "temp", + } + ] + } + } + + +NODE_CLASS_MAPPINGS = { + "TrainLoraNode": TrainLoraNode, + "SaveLoRANode": SaveLoRA, + "LoraModelLoader": LoraModelLoader, + "LoadImageSetFromFolderNode": LoadImageSetFromFolderNode, + "LossGraphNode": LossGraphNode, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "TrainLoraNode": "Train LoRA", + "SaveLoRANode": "Save LoRA Weights", + "LoraModelLoader": "Load LoRA Model", + "LoadImageSetFromFolderNode": "Load Image Dataset from Folder", + "LossGraphNode": "Plot Loss Graph", +} diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 9dda64597..d6097a104 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -268,8 +268,9 @@ class WanVaceToVideo: trim_latent = reference_image.shape[2] mask = mask.unsqueeze(0) - positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength}) - negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength}) + + positive = node_helpers.conditioning_set_values(positive, {"vace_frames": [control_video_latent], "vace_mask": [mask], "vace_strength": [strength]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"vace_frames": [control_video_latent], "vace_mask": [mask], "vace_strength": [strength]}, append=True) latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device()) out_latent = {} @@ -297,6 +298,90 @@ class TrimVideoLatent: samples_out["samples"] = s1[:, :, trim_amount:] return 
(samples_out,) +class WanCameraImageToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), + "start_image": ("IMAGE", ), + "camera_conditions": ("WAN_CAMERA_EMBEDDING", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) + + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(start_image[:, :, :, :3]) + concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent}) + + if camera_conditions is not None: + positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions}) + negative = node_helpers.conditioning_set_values(negative, {'camera_conditions': camera_conditions}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent) + +class WanPhantomSubjectToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"images": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, images): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + cond2 = negative + if images is not None: + 
images = comfy.utils.common_upscale(images[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + latent_images = [] + for i in images: + latent_images += [vae.encode(i.unsqueeze(0)[:, :, :, :3])] + concat_latent_image = torch.cat(latent_images, dim=2) + + positive = node_helpers.conditioning_set_values(positive, {"time_dim_concat": concat_latent_image}) + cond2 = node_helpers.conditioning_set_values(negative, {"time_dim_concat": concat_latent_image}) + negative = node_helpers.conditioning_set_values(negative, {"time_dim_concat": comfy.latent_formats.Wan21().process_out(torch.zeros_like(concat_latent_image))}) + + out_latent = {} + out_latent["samples"] = latent + return (positive, cond2, negative, out_latent) NODE_CLASS_MAPPINGS = { "WanImageToVideo": WanImageToVideo, @@ -305,4 +390,6 @@ NODE_CLASS_MAPPINGS = { "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo, "WanVaceToVideo": WanVaceToVideo, "TrimVideoLatent": TrimVideoLatent, + "WanCameraImageToVideo": WanCameraImageToVideo, + "WanPhantomSubjectToVideo": WanPhantomSubjectToVideo, } diff --git a/comfy_extras/nodes_webcam.py b/comfy_extras/nodes_webcam.py index 062b15cf8..5bf80b4c6 100644 --- a/comfy_extras/nodes_webcam.py +++ b/comfy_extras/nodes_webcam.py @@ -23,6 +23,10 @@ class WebcamCapture(nodes.LoadImage): def load_capture(self, image, **kwargs): return super().load_image(folder_paths.get_annotated_filepath(image)) + @classmethod + def IS_CHANGED(cls, image, width, height, capture_on_queue): + return super().IS_CHANGED(image) + NODE_CLASS_MAPPINGS = { "WebcamCapture": WebcamCapture, diff --git a/comfyui_version.py b/comfyui_version.py index b740b378d..fedd3466f 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. 
-__version__ = "0.3.34" +__version__ = "0.3.41" diff --git a/execution.py b/execution.py index e5d1c69d9..d0012afda 100644 --- a/execution.py +++ b/execution.py @@ -1,23 +1,35 @@ -import sys import copy -import logging -import threading import heapq +import inspect +import logging +import sys +import threading import time import traceback from enum import Enum -import inspect from typing import List, Literal, NamedTuple, Optional import torch -import nodes import comfy.model_management -from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker -from comfy_execution.graph_utils import is_link, GraphBuilder -from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID +import nodes +from comfy_execution.caching import ( + CacheKeySetID, + CacheKeySetInputSignature, + DependencyAwareCache, + HierarchicalCache, + LRUCache, +) +from comfy_execution.graph import ( + DynamicPrompt, + ExecutionBlocker, + ExecutionList, + get_input_info, +) +from comfy_execution.graph_utils import GraphBuilder, is_link from comfy_execution.validation import validate_node_input + class ExecutionResult(Enum): SUCCESS = 0 FAILURE = 1 @@ -909,7 +921,6 @@ class PromptQueue: self.currently_running = {} self.history = {} self.flags = {} - server.prompt_queue = self def put(self, item): with self.mutex: @@ -954,6 +965,7 @@ class PromptQueue: self.history[prompt[1]].update(history_result) self.server.queue_updated() + # Note: slow def get_current_queue(self): with self.mutex: out = [] @@ -961,6 +973,13 @@ class PromptQueue: out += [x] return (out, copy.deepcopy(self.queue)) + # read-safe as long as queue items are immutable + def get_current_queue_volatile(self): + with self.mutex: + running = [x for x in self.currently_running.values()] + queued = copy.copy(self.queue) + return (running, queued) + def get_tasks_remaining(self): with self.mutex: return len(self.queue) + len(self.currently_running) diff --git a/fix_torch.py b/fix_torch.py deleted file mode 100644 index ce117b639..000000000 --- a/fix_torch.py +++ /dev/null @@ -1,28 +0,0 @@ -import importlib.util -import shutil -import os -import ctypes -import logging - - -def fix_pytorch_libomp(): - """ - Fix PyTorch libomp DLL issue on Windows by copying the correct DLL file if needed. 
- """ - torch_spec = importlib.util.find_spec("torch") - for folder in torch_spec.submodule_search_locations: - lib_folder = os.path.join(folder, "lib") - test_file = os.path.join(lib_folder, "fbgemm.dll") - dest = os.path.join(lib_folder, "libomp140.x86_64.dll") - if os.path.exists(dest): - break - - with open(test_file, "rb") as f: - contents = f.read() - if b"libomp140.x86_64.dll" not in contents: - break - try: - ctypes.cdll.LoadLibrary(test_file) - except FileNotFoundError: - logging.warning("Detected pytorch version with libomp issue, patching.") - shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest) diff --git a/folder_paths.py b/folder_paths.py index f0b3fd103..9ec952940 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -276,6 +276,9 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str]) def get_full_path(folder_name: str, filename: str) -> str | None: + """ + Get the full path of a file in a folder, has to be a file + """ global folder_names_and_paths folder_name = map_legacy(folder_name) if folder_name not in folder_names_and_paths: @@ -293,6 +296,9 @@ def get_full_path(folder_name: str, filename: str) -> str | None: def get_full_path_or_raise(folder_name: str, filename: str) -> str: + """ + Get the full path of a file in a folder, has to be a file + """ full_path = get_full_path(folder_name, filename) if full_path is None: raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.") @@ -394,3 +400,26 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im os.makedirs(full_output_folder, exist_ok=True) counter = 1 return full_output_folder, filename, counter, subfolder, filename_prefix + +def get_input_subfolders() -> list[str]: + """Returns a list of all subfolder paths in the input directory, recursively. + + Returns: + List of folder paths relative to the input directory, excluding the root directory + """ + input_dir = get_input_directory() + folders = [] + + try: + if not os.path.exists(input_dir): + return [] + + for root, dirs, _ in os.walk(input_dir): + rel_path = os.path.relpath(root, input_dir) + if rel_path != ".": # Only include non-root directories + # Normalize path separators to forward slashes + folders.append(rel_path.replace(os.sep, '/')) + + return sorted(folders) + except FileNotFoundError: + return [] diff --git a/main.py b/main.py index 221e48e41..c8c4194d4 100644 --- a/main.py +++ b/main.py @@ -17,7 +17,6 @@ if __name__ == "__main__": os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' os.environ['DO_NOT_TRACK'] = '1' - setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) def apply_custom_paths(): @@ -125,13 +124,6 @@ if __name__ == "__main__": import cuda_malloc -if args.windows_standalone_build: - try: - from fix_torch import fix_pytorch_libomp - fix_pytorch_libomp() - except: - pass - import comfy.utils import execution @@ -245,6 +237,15 @@ def cleanup_temp(): shutil.rmtree(temp_dir, ignore_errors=True) +def setup_database(): + try: + from app.database.db import init_db, dependencies_available + if dependencies_available(): + init_db() + except Exception as e: + logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}") + + def start_comfyui(asyncio_loop=None): """ Starts the ComfyUI server using the provided asyncio event loop or creates a new one. 
@@ -267,18 +268,18 @@ def start_comfyui(asyncio_loop=None): asyncio_loop = asyncio.new_event_loop() asyncio.set_event_loop(asyncio_loop) prompt_server = server.PromptServer(asyncio_loop) - q = execution.PromptQueue(prompt_server) hook_breaker_ac10a0.save_functions() nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes, init_api_nodes=not args.disable_api_nodes) hook_breaker_ac10a0.restore_functions() cuda_malloc_warning() + setup_database() prompt_server.add_routes() hijack_progress(prompt_server) - threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start() + threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start() if args.quick_test_for_ci: exit(0) @@ -308,6 +309,9 @@ if __name__ == "__main__": logging.info("Python version: {}".format(sys.version)) logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) + if sys.version_info.major == 3 and sys.version_info.minor < 10: + logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.") + event_loop, _, start_all_func = start_comfyui() try: x = start_all_func() diff --git a/node_helpers.py b/node_helpers.py index c3e1a14ca..4ff960ef8 100644 --- a/node_helpers.py +++ b/node_helpers.py @@ -5,12 +5,18 @@ from comfy.cli_args import args from PIL import ImageFile, UnidentifiedImageError -def conditioning_set_values(conditioning, values={}): +def conditioning_set_values(conditioning, values={}, append=False): c = [] for t in conditioning: n = [t[0], t[1].copy()] for k in values: - n[1][k] = values[k] + val = values[k] + if append: + old_val = n[1].get(k, None) + if old_val is not None: + val = old_val + val + + n[1][k] = val c.append(n) return c diff --git a/nodes.py b/nodes.py index 13d176a03..c8d0cacb5 100644 --- a/nodes.py +++ b/nodes.py @@ -1103,16 +1103,7 @@ class unCLIPConditioning: if strength == 0: return (conditioning, ) - c = [] - for t in conditioning: - o = t[1].copy() - x = {"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation} - if "unclip_conditioning" in o: - o["unclip_conditioning"] = o["unclip_conditioning"][:] + [x] - else: - o["unclip_conditioning"] = [x] - n = [t[0], o] - c.append(n) + c = node_helpers.conditioning_set_values(conditioning, {"unclip_conditioning": [{"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation}]}, append=True) return (c, ) class GLIGENLoader: @@ -1940,7 +1931,7 @@ class ImagePadForOutpaint: mask[top:top + d2, left:left + d3] = t - return (new_image, mask) + return (new_image, mask.unsqueeze(0)) NODE_CLASS_MAPPINGS = { @@ -2070,11 +2061,13 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ImagePadForOutpaint": "Pad Image for Outpainting", "ImageBatch": "Batch Images", "ImageCrop": "Image Crop", + "ImageStitch": "Image Stitch", "ImageBlend": "Image Blend", "ImageBlur": "Image Blur", "ImageQuantize": "Image Quantize", "ImageSharpen": "Image Sharpen", "ImageScaleToTotalPixels": "Scale Image to Total Pixels", + "GetImageSize": "Get Image Size", # _for_testing "VAEDecodeTiled": "VAE Decode (Tiled)", "VAEEncodeTiled": "VAE Encode (Tiled)", @@ -2132,6 +2125,25 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir) + try: + from comfy_config import config_parser + + project_config = config_parser.extract_node_configuration(module_path) + + 
web_dir_name = project_config.tool_comfy.web + + if web_dir_name: + web_dir_path = os.path.join(module_path, web_dir_name) + + if os.path.isdir(web_dir_path): + project_name = project_config.project.name + + EXTENSION_WEB_DIRS[project_name] = web_dir_path + + logging.info("Automatically register web folder {} for {}".format(web_dir_name, project_name)) + except Exception as e: + logging.warning(f"Unable to parse pyproject.toml due to lack dependency pydantic-settings, please run 'pip install -r requirements.txt': {e}") + if hasattr(module, "WEB_DIRECTORY") and getattr(module, "WEB_DIRECTORY") is not None: web_dir = os.path.abspath(os.path.join(module_dir, getattr(module, "WEB_DIRECTORY"))) if os.path.isdir(web_dir): @@ -2219,6 +2231,7 @@ def init_builtin_extra_nodes(): "nodes_model_downscale.py", "nodes_images.py", "nodes_video_model.py", + "nodes_train.py", "nodes_sag.py", "nodes_perpneg.py", "nodes_stable3d.py", @@ -2262,9 +2275,11 @@ def init_builtin_extra_nodes(): "nodes_optimalsteps.py", "nodes_hidream.py", "nodes_fresca.py", + "nodes_apg.py", "nodes_preview_any.py", "nodes_ace.py", "nodes_string.py", + "nodes_camera_trajectory.py", ] import_failed = [] @@ -2289,6 +2304,10 @@ def init_builtin_api_nodes(): "nodes_pixverse.py", "nodes_stability.py", "nodes_pika.py", + "nodes_runway.py", + "nodes_tripo.py", + "nodes_rodin.py", + "nodes_gemini.py", ] if not load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"): diff --git a/pyproject.toml b/pyproject.toml index 80061b39a..c572ad4c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.34" +version = "0.3.41" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" diff --git a/requirements.txt b/requirements.txt index 8f7a78984..336ec9d57 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ -comfyui-frontend-package==1.19.9 -comfyui-workflow-templates==0.1.14 +comfyui-frontend-package==1.21.7 +comfyui-workflow-templates==0.1.28 +comfyui-embedded-docs==0.2.2 torch torchsde torchvision @@ -17,6 +18,8 @@ Pillow scipy tqdm psutil +alembic +SQLAlchemy #non essential dependencies: kornia>=0.7.1 @@ -24,3 +27,4 @@ spandrel soundfile av>=14.2.0 pydantic~=2.0 +pydantic-settings~=2.0 diff --git a/script_examples/basic_api_example.py b/script_examples/basic_api_example.py index c916e6cb9..9128420c4 100644 --- a/script_examples/basic_api_example.py +++ b/script_examples/basic_api_example.py @@ -101,6 +101,14 @@ prompt_text = """ def queue_prompt(prompt): p = {"prompt": prompt} + + # If the workflow contains API nodes, you can add a Comfy API key to the `extra_data`` field of the payload. 
+ # p["extra_data"] = { + # "api_key_comfy_org": "comfyui-87d01e28d*******************************************************" # replace with real key + # } + # See: https://docs.comfy.org/tutorials/api-nodes/overview + # Generate a key here: https://platform.comfy.org/login + data = json.dumps(p).encode('utf-8') req = request.Request("http://127.0.0.1:8188/prompt", data=data) request.urlopen(req) diff --git a/server.py b/server.py index cb1c6a8fd..878b5eeb1 100644 --- a/server.py +++ b/server.py @@ -29,6 +29,7 @@ import comfy.model_management import node_helpers from comfyui_version import __version__ from app.frontend_management import FrontendManager + from app.user_manager import UserManager from app.model_manager import ModelFileManager from app.custom_node_manager import CustomNodeManager @@ -159,7 +160,7 @@ class PromptServer(): self.custom_node_manager = CustomNodeManager() self.internal_routes = InternalRoutes(self) self.supports = ["custom_nodes_from_web"] - self.prompt_queue = None + self.prompt_queue = execution.PromptQueue(self) self.loop = loop self.messages = asyncio.Queue() self.client_session:Optional[aiohttp.ClientSession] = None @@ -226,7 +227,7 @@ class PromptServer(): return response @routes.get("/embeddings") - def get_embeddings(self): + def get_embeddings(request): embeddings = folder_paths.get_filename_list("embeddings") return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings))) @@ -282,7 +283,6 @@ class PromptServer(): a.update(f.read()) b.update(image.file.read()) image.file.seek(0) - f.close() return a.hexdigest() == b.hexdigest() return False @@ -390,7 +390,7 @@ class PromptServer(): async def view_image(request): if "filename" in request.rel_url.query: filename = request.rel_url.query["filename"] - filename,output_dir = folder_paths.annotated_filepath(filename) + filename, output_dir = folder_paths.annotated_filepath(filename) if not filename: return web.Response(status=400) @@ -476,9 +476,8 @@ class PromptServer(): # Get content type from mimetype, defaulting to 'application/octet-stream' content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream' - # For security, force certain extensions to download instead of display - file_extension = os.path.splitext(filename)[1].lower() - if file_extension in {'.html', '.htm', '.js', '.css'}: + # For security, force certain mimetypes to download instead of display + if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}: content_type = 'application/octet-stream' # Forces download return web.FileResponse( @@ -621,7 +620,7 @@ class PromptServer(): @routes.get("/queue") async def get_queue(request): queue_info = {} - current_queue = self.prompt_queue.get_current_queue() + current_queue = self.prompt_queue.get_current_queue_volatile() queue_info['queue_running'] = current_queue[0] queue_info['queue_pending'] = current_queue[1] return web.json_response(queue_info) @@ -746,6 +745,13 @@ class PromptServer(): web.static('/templates', workflow_templates_path) ]) + # Serve embedded documentation from the package + embedded_docs_path = FrontendManager.embedded_docs_path() + if embedded_docs_path: + self.app.add_routes([ + web.static('/docs', embedded_docs_path) + ]) + self.app.add_routes([ web.static('/', self.web_root), ]) @@ -782,7 +788,7 @@ class PromptServer(): if hasattr(Image, 'Resampling'): resampling = Image.Resampling.BILINEAR else: - resampling = Image.ANTIALIAS + resampling = Image.Resampling.LANCZOS image = 
ImageOps.contain(image, (max_size, max_size), resampling) type_num = 1 diff --git a/tests-unit/comfy_api_test/video_types_test.py b/tests-unit/comfy_api_test/video_types_test.py new file mode 100644 index 000000000..b25fcb1ca --- /dev/null +++ b/tests-unit/comfy_api_test/video_types_test.py @@ -0,0 +1,239 @@ +import pytest +import torch +import tempfile +import os +import av +import io +from fractions import Fraction +from comfy_api.input_impl.video_types import VideoFromFile, VideoFromComponents +from comfy_api.util.video_types import VideoComponents +from comfy_api.input.basic_types import AudioInput +from av.error import InvalidDataError + +EPSILON = 0.0001 + + +@pytest.fixture +def sample_images(): + """3-frame 2x2 RGB video tensor""" + return torch.rand(3, 2, 2, 3) + + +@pytest.fixture +def sample_audio(): + """Stereo audio with 44.1kHz sample rate""" + return AudioInput( + { + "waveform": torch.rand(1, 2, 1000), + "sample_rate": 44100, + } + ) + + +@pytest.fixture +def video_components(sample_images, sample_audio): + """VideoComponents with images, audio, and metadata""" + return VideoComponents( + images=sample_images, + audio=sample_audio, + frame_rate=Fraction(30), + metadata={"test": "metadata"}, + ) + + +def create_test_video(width=4, height=4, frames=3, fps=30): + """Helper to create a temporary video file""" + tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) + with av.open(tmp.name, mode="w") as container: + stream = container.add_stream("h264", rate=fps) + stream.width = width + stream.height = height + stream.pix_fmt = "yuv420p" + + for i in range(frames): + frame = av.VideoFrame.from_ndarray( + torch.ones(height, width, 3, dtype=torch.uint8).numpy() * (i * 85), + format="rgb24", + ) + frame = frame.reformat(format="yuv420p") + packet = stream.encode(frame) + container.mux(packet) + + # Flush + packet = stream.encode(None) + container.mux(packet) + + return tmp.name + + +@pytest.fixture +def simple_video_file(): + """4x4 video with 3 frames at 30fps""" + file_path = create_test_video() + yield file_path + os.unlink(file_path) + + +def test_video_from_components_get_duration(video_components): + """Duration calculated correctly from frame count and frame rate""" + video = VideoFromComponents(video_components) + duration = video.get_duration() + + expected_duration = 3.0 / 30.0 + assert duration == pytest.approx(expected_duration) + + +def test_video_from_components_get_duration_different_frame_rates(sample_images): + """Duration correct for different frame rates including fractional""" + # Test with 60 fps + components_60fps = VideoComponents(images=sample_images, frame_rate=Fraction(60)) + video_60fps = VideoFromComponents(components_60fps) + assert video_60fps.get_duration() == pytest.approx(3.0 / 60.0) + + # Test with fractional frame rate (23.976fps) + components_frac = VideoComponents( + images=sample_images, frame_rate=Fraction(24000, 1001) + ) + video_frac = VideoFromComponents(components_frac) + expected_frac = 3.0 / (24000.0 / 1001.0) + assert video_frac.get_duration() == pytest.approx(expected_frac) + + +def test_video_from_components_get_duration_empty_video(): + """Duration is zero for empty video""" + empty_components = VideoComponents( + images=torch.zeros(0, 2, 2, 3), frame_rate=Fraction(30) + ) + video = VideoFromComponents(empty_components) + assert video.get_duration() == 0.0 + + +def test_video_from_components_get_dimensions(video_components): + """Dimensions returned correctly from image tensor shape""" + video = 
+
+
+def test_video_from_components_get_dimensions(video_components):
+    """Dimensions returned correctly from image tensor shape"""
+    video = VideoFromComponents(video_components)
+    width, height = video.get_dimensions()
+    assert width == 2
+    assert height == 2
+
+
+def test_video_from_file_get_duration(simple_video_file):
+    """Duration extracted from file metadata"""
+    video = VideoFromFile(simple_video_file)
+    duration = video.get_duration()
+    assert duration == pytest.approx(0.1, abs=0.01)
+
+
+def test_video_from_file_get_dimensions(simple_video_file):
+    """Dimensions read from stream without decoding frames"""
+    video = VideoFromFile(simple_video_file)
+    width, height = video.get_dimensions()
+    assert width == 4
+    assert height == 4
+
+
+def test_video_from_file_bytesio_input():
+    """VideoFromFile works with BytesIO input"""
+    buffer = io.BytesIO()
+    with av.open(buffer, mode="w", format="mp4") as container:
+        stream = container.add_stream("h264", rate=30)
+        stream.width = 2
+        stream.height = 2
+        stream.pix_fmt = "yuv420p"
+
+        frame = av.VideoFrame.from_ndarray(
+            torch.zeros(2, 2, 3, dtype=torch.uint8).numpy(), format="rgb24"
+        )
+        frame = frame.reformat(format="yuv420p")
+        packet = stream.encode(frame)
+        container.mux(packet)
+        packet = stream.encode(None)
+        container.mux(packet)
+
+    buffer.seek(0)
+    video = VideoFromFile(buffer)
+
+    assert video.get_dimensions() == (2, 2)
+    assert video.get_duration() == pytest.approx(1 / 30, abs=0.01)
+
+
+def test_video_from_file_invalid_file_error():
+    """InvalidDataError raised for non-video files"""
+    with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as tmp:
+        tmp.write(b"not a video file")
+        tmp.flush()
+        tmp_name = tmp.name
+
+    try:
+        with pytest.raises(InvalidDataError):
+            video = VideoFromFile(tmp_name)
+            video.get_dimensions()
+    finally:
+        os.unlink(tmp_name)
+
+
+def test_video_from_file_audio_only_error():
+    """ValueError raised for audio-only files"""
+    with tempfile.NamedTemporaryFile(suffix=".m4a", delete=False) as tmp:
+        tmp_name = tmp.name
+
+    try:
+        with av.open(tmp_name, mode="w") as container:
+            stream = container.add_stream("aac", rate=44100)
+            stream.sample_rate = 44100
+            stream.format = "fltp"
+
+            audio_data = torch.zeros(1, 1024).numpy()
+            audio_frame = av.AudioFrame.from_ndarray(
+                audio_data, format="fltp", layout="mono"
+            )
+            audio_frame.sample_rate = 44100
+            audio_frame.pts = 0
+            packet = stream.encode(audio_frame)
+            container.mux(packet)
+
+            for packet in stream.encode(None):
+                container.mux(packet)
+
+        with pytest.raises(ValueError, match="No video stream found"):
+            video = VideoFromFile(tmp_name)
+            video.get_dimensions()
+    finally:
+        os.unlink(tmp_name)
+
+
+def test_single_frame_video():
+    """Single frame video has correct duration"""
+    components = VideoComponents(
+        images=torch.rand(1, 10, 10, 3), frame_rate=Fraction(1)
+    )
+    video = VideoFromComponents(components)
+    assert video.get_duration() == 1.0
+
+
+@pytest.mark.parametrize(
+    "frame_rate,expected_fps",
+    [
+        (Fraction(24000, 1001), 24000 / 1001),
+        (Fraction(30000, 1001), 30000 / 1001),
+        (Fraction(25, 1), 25.0),
+        (Fraction(50, 2), 25.0),
+    ],
+)
+def test_fractional_frame_rates(frame_rate, expected_fps):
+    """Duration calculated correctly for various fractional frame rates"""
+    components = VideoComponents(images=torch.rand(100, 4, 4, 3), frame_rate=frame_rate)
+    video = VideoFromComponents(components)
+    duration = video.get_duration()
+    expected_duration = 100.0 / expected_fps
+    assert duration == pytest.approx(expected_duration)
+
+
+def test_duration_consistency(video_components):
+    """get_duration() consistent with manual calculation from components"""
+    video = VideoFromComponents(video_components)
+
+    duration = video.get_duration()
+    components = video.get_components()
+    manual_duration = float(components.images.shape[0] / components.frame_rate)
+
+    assert duration == pytest.approx(manual_duration)
diff --git a/tests-unit/comfy_extras_test/__init__.py b/tests-unit/comfy_extras_test/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests-unit/comfy_extras_test/image_stitch_test.py b/tests-unit/comfy_extras_test/image_stitch_test.py
new file mode 100644
index 000000000..b5a0f022c
--- /dev/null
+++ b/tests-unit/comfy_extras_test/image_stitch_test.py
@@ -0,0 +1,245 @@
+import torch
+from unittest.mock import patch, MagicMock
+
+# Mock nodes module to prevent CUDA initialization during import
+mock_nodes = MagicMock()
+mock_nodes.MAX_RESOLUTION = 16384
+
+# Mock server module for PromptServer
+mock_server = MagicMock()
+
+with patch.dict('sys.modules', {'nodes': mock_nodes, 'server': mock_server}):
+    from comfy_extras.nodes_images import ImageStitch
+
+
+class TestImageStitch:
+
+    def create_test_image(self, batch_size=1, height=64, width=64, channels=3):
+        """Helper to create test images with specific dimensions"""
+        return torch.rand(batch_size, height, width, channels)
+
+    def test_no_image2_passthrough(self):
+        """Test that when image2 is None, image1 is returned unchanged"""
+        node = ImageStitch()
+        image1 = self.create_test_image()
+
+        result = node.stitch(image1, "right", True, 0, "white", image2=None)
+
+        assert len(result) == 1
+        assert torch.equal(result[0], image1)
+
+    def test_basic_horizontal_stitch_right(self):
+        """Test basic horizontal stitching to the right"""
+        node = ImageStitch()
+        image1 = self.create_test_image(height=32, width=32)
+        image2 = self.create_test_image(height=32, width=24)
+
+        result = node.stitch(image1, "right", False, 0, "white", image2)
+
+        assert result[0].shape == (1, 32, 56, 3)  # 32 + 24 width
+
+    def test_basic_horizontal_stitch_left(self):
+        """Test basic horizontal stitching to the left"""
+        node = ImageStitch()
+        image1 = self.create_test_image(height=32, width=32)
+        image2 = self.create_test_image(height=32, width=24)
+
+        result = node.stitch(image1, "left", False, 0, "white", image2)
+
+        assert result[0].shape == (1, 32, 56, 3)  # 24 + 32 width
+
+    def test_basic_vertical_stitch_down(self):
+        """Test basic vertical stitching downward"""
+        node = ImageStitch()
+        image1 = self.create_test_image(height=32, width=32)
+        image2 = self.create_test_image(height=24, width=32)
+
+        result = node.stitch(image1, "down", False, 0, "white", image2)
+
+        assert result[0].shape == (1, 56, 32, 3)  # 32 + 24 height
+
+    def test_basic_vertical_stitch_up(self):
+        """Test basic vertical stitching upward"""
+        node = ImageStitch()
+        image1 = self.create_test_image(height=32, width=32)
+        image2 = self.create_test_image(height=24, width=32)
+
+        result = node.stitch(image1, "up", False, 0, "white", image2)
+
+        assert result[0].shape == (1, 56, 32, 3)  # 24 + 32 height
+
+    def test_size_matching_horizontal(self):
+        """Test size matching for horizontal concatenation"""
+        node = ImageStitch()
+        image1 = self.create_test_image(height=64, width=64)
+        image2 = self.create_test_image(height=32, width=32)  # Different aspect ratio
+
+        result = node.stitch(image1, "right", True, 0, "white", image2)
+
+        # image2 should be resized to match image1's height (64) with preserved aspect ratio
+        expected_width = 64 + 64  # original + resized (32*64/32 = 64)
+        assert result[0].shape == (1, 64, expected_width, 3)
+    def test_size_matching_vertical(self):
+        """Test size matching for vertical concatenation"""
+        node = ImageStitch()
+        image1 = self.create_test_image(height=64, width=64)
+        image2 = self.create_test_image(height=32, width=32)
+
+        result = node.stitch(image1, "down", True, 0, "white", image2)
+
+        # image2 should be resized to match image1's width (64) with preserved aspect ratio
+        expected_height = 64 + 64  # original + resized (32*64/32 = 64)
+        assert result[0].shape == (1, expected_height, 64, 3)
+
+    def test_padding_for_mismatched_heights_horizontal(self):
+        """Test padding when heights don't match in horizontal concatenation"""
+        node = ImageStitch()
+        image1 = self.create_test_image(height=64, width=32)
+        image2 = self.create_test_image(height=48, width=24)  # Shorter height
+
+        result = node.stitch(image1, "right", False, 0, "white", image2)
+
+        # Both images should be padded to height 64
+        assert result[0].shape == (1, 64, 56, 3)  # 32 + 24 width, max(64,48) height
+
+    def test_padding_for_mismatched_widths_vertical(self):
+        """Test padding when widths don't match in vertical concatenation"""
+        node = ImageStitch()
+        image1 = self.create_test_image(height=32, width=64)
+        image2 = self.create_test_image(height=24, width=48)  # Narrower width
+
+        result = node.stitch(image1, "down", False, 0, "white", image2)
+
+        # Both images should be padded to width 64
+        assert result[0].shape == (1, 56, 64, 3)  # 32 + 24 height, max(64,48) width
+
+    def test_spacing_horizontal(self):
+        """Test spacing addition in horizontal concatenation"""
+        node = ImageStitch()
+        image1 = self.create_test_image(height=32, width=32)
+        image2 = self.create_test_image(height=32, width=24)
+        spacing_width = 16
+
+        result = node.stitch(image1, "right", False, spacing_width, "white", image2)
+
+        # Expected width: 32 + 16 (spacing) + 24 = 72
+        assert result[0].shape == (1, 32, 72, 3)
+
+    def test_spacing_vertical(self):
+        """Test spacing addition in vertical concatenation"""
+        node = ImageStitch()
+        image1 = self.create_test_image(height=32, width=32)
+        image2 = self.create_test_image(height=24, width=32)
+        spacing_width = 16
+
+        result = node.stitch(image1, "down", False, spacing_width, "white", image2)
+
+        # Expected height: 32 + 16 (spacing) + 24 = 72
+        assert result[0].shape == (1, 72, 32, 3)
+
+    def test_spacing_color_values(self):
+        """Test that spacing colors are applied correctly"""
+        node = ImageStitch()
+        image1 = self.create_test_image(height=32, width=32)
+        image2 = self.create_test_image(height=32, width=32)
+
+        # Test white spacing
+        result_white = node.stitch(image1, "right", False, 16, "white", image2)
+        # Check that spacing region contains white values (close to 1.0)
+        spacing_region = result_white[0][:, :, 32:48, :]  # Middle 16 pixels
+        assert torch.all(spacing_region >= 0.9)  # Should be close to white
+
+        # Test black spacing
+        result_black = node.stitch(image1, "right", False, 16, "black", image2)
+        spacing_region = result_black[0][:, :, 32:48, :]
+        assert torch.all(spacing_region <= 0.1)  # Should be close to black
+
+    def test_odd_spacing_width_made_even(self):
+        """Test that odd spacing widths are made even"""
+        node = ImageStitch()
+        image1 = self.create_test_image(height=32, width=32)
+        image2 = self.create_test_image(height=32, width=32)
+
+        # Use odd spacing width
+        result = node.stitch(image1, "right", False, 15, "white", image2)
+
+        # Should be made even (16), so total width = 32 + 16 + 32 = 80
+        assert result[0].shape == (1, 32, 80, 3)
+    def test_batch_size_matching(self):
+        """Test that different batch sizes are handled correctly"""
+        node = ImageStitch()
+        image1 = self.create_test_image(batch_size=2, height=32, width=32)
+        image2 = self.create_test_image(batch_size=1, height=32, width=32)
+
+        result = node.stitch(image1, "right", False, 0, "white", image2)
+
+        # Should match larger batch size
+        assert result[0].shape == (2, 32, 64, 3)
+
+    def test_channel_matching_rgb_to_rgba(self):
+        """Test that channel differences are handled (RGB + alpha)"""
+        node = ImageStitch()
+        image1 = self.create_test_image(channels=3)  # RGB
+        image2 = self.create_test_image(channels=4)  # RGBA
+
+        result = node.stitch(image1, "right", False, 0, "white", image2)
+
+        # Should have 4 channels (RGBA)
+        assert result[0].shape[-1] == 4
+
+    def test_channel_matching_rgba_to_rgb(self):
+        """Test that channel differences are handled (RGBA + RGB)"""
+        node = ImageStitch()
+        image1 = self.create_test_image(channels=4)  # RGBA
+        image2 = self.create_test_image(channels=3)  # RGB
+
+        result = node.stitch(image1, "right", False, 0, "white", image2)
+
+        # Should have 4 channels (RGBA)
+        assert result[0].shape[-1] == 4
+
+    def test_all_color_options(self):
+        """Test all available color options"""
+        node = ImageStitch()
+        image1 = self.create_test_image(height=32, width=32)
+        image2 = self.create_test_image(height=32, width=32)
+
+        colors = ["white", "black", "red", "green", "blue"]
+
+        for color in colors:
+            result = node.stitch(image1, "right", False, 16, color, image2)
+            assert result[0].shape == (1, 32, 80, 3)  # Basic shape check
+
+    def test_all_directions(self):
+        """Test all direction options"""
+        node = ImageStitch()
+        image1 = self.create_test_image(height=32, width=32)
+        image2 = self.create_test_image(height=32, width=32)
+
+        directions = ["right", "left", "up", "down"]
+
+        for direction in directions:
+            result = node.stitch(image1, direction, False, 0, "white", image2)
+            # Bind the conditional to the expected tuple, not the whole assert;
+            # `assert a == b if cond else c` always passes when c is a truthy tuple.
+            expected = (1, 32, 64, 3) if direction in ["right", "left"] else (1, 64, 32, 3)
+            assert result[0].shape == expected
+
+    def test_batch_size_channel_spacing_integration(self):
+        """Test integration of batch matching, channel matching, size matching, and spacing"""
+        node = ImageStitch()
+        image1 = self.create_test_image(batch_size=2, height=64, width=48, channels=3)
+        image2 = self.create_test_image(batch_size=1, height=32, width=32, channels=4)
+
+        result = node.stitch(image1, "right", True, 8, "red", image2)
+
+        # Should handle: batch matching, size matching, channel matching, spacing
+        assert result[0].shape[0] == 2  # Batch size matched
+        assert result[0].shape[-1] == 4  # Channels matched to max
+        assert result[0].shape[1] == 64  # Height from image1 (size matching)
+        # Width should be: 48 + 8 (spacing) + resized_image2_width
+        expected_image2_width = int(64 * (32 / 32))  # Resized to height 64
+        expected_total_width = 48 + 8 + expected_image2_width
+        assert result[0].shape[2] == expected_total_width
+
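As a reading aid for the tests above, a minimal usage sketch of the ImageStitch node outside pytest. The sys.modules stubbing mirrors the test file's import guard, the positional argument order is taken from the stitch(...) calls above, and the expected shape follows the size-matching and spacing rules the tests assert; treat it as a sketch under those assumptions, not canonical API documentation:

import torch
from unittest.mock import patch, MagicMock

# Same import guard as the test file: stub modules that would pull in CUDA.
mock_nodes = MagicMock()
mock_nodes.MAX_RESOLUTION = 16384

with patch.dict('sys.modules', {'nodes': mock_nodes, 'server': MagicMock()}):
    from comfy_extras.nodes_images import ImageStitch

node = ImageStitch()
image1 = torch.rand(1, 64, 64, 3)  # (batch, height, width, channels)
image2 = torch.rand(1, 32, 32, 3)  # smaller; match_image_size rescales it

# stitch(image1, direction, match_image_size, spacing_width, spacing_color, image2)
(result,) = node.stitch(image1, "right", True, 8, "white", image2)

# image2 is resized to height 64 (width 64 by aspect ratio), plus 8px spacing:
print(result.shape)  # torch.Size([1, 64, 136, 3])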
diff --git a/tests-unit/folder_paths_test/misc_test.py b/tests-unit/folder_paths_test/misc_test.py
new file mode 100644
index 000000000..fcf667453
--- /dev/null
+++ b/tests-unit/folder_paths_test/misc_test.py
@@ -0,0 +1,51 @@
+import pytest
+import os
+import tempfile
+from folder_paths import get_input_subfolders, set_input_directory
+
+@pytest.fixture(scope="module")
+def mock_folder_structure():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # Create a nested folder structure
+        folders = [
+            "folder1",
+            "folder1/subfolder1",
+            "folder1/subfolder2",
+            "folder2",
+            "folder2/deep",
+            "folder2/deep/nested",
+            "empty_folder"
+        ]
+
+        # Create the folders
+        for folder in folders:
+            os.makedirs(os.path.join(temp_dir, folder))
+
+        # Add some files to test they're not included
+        with open(os.path.join(temp_dir, "root_file.txt"), "w") as f:
+            f.write("test")
+        with open(os.path.join(temp_dir, "folder1", "test.txt"), "w") as f:
+            f.write("test")
+
+        set_input_directory(temp_dir)
+        yield temp_dir
+
+
+def test_gets_all_folders(mock_folder_structure):
+    folders = get_input_subfolders()
+    expected = ["folder1", "folder1/subfolder1", "folder1/subfolder2",
+                "folder2", "folder2/deep", "folder2/deep/nested", "empty_folder"]
+    assert sorted(folders) == sorted(expected)
+
+
+def test_handles_nonexistent_input_directory():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        nonexistent = os.path.join(temp_dir, "nonexistent")
+        set_input_directory(nonexistent)
+        assert get_input_subfolders() == []
+
+
+def test_empty_input_directory():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        set_input_directory(temp_dir)
+        assert get_input_subfolders() == []  # Empty since we don't include root
diff --git a/utils/install_util.py b/utils/install_util.py
new file mode 100644
index 000000000..0f59bcf91
--- /dev/null
+++ b/utils/install_util.py
@@ -0,0 +1,18 @@
+from pathlib import Path
+import sys
+
+# The path to the requirements.txt file
+requirements_path = Path(__file__).parents[1] / "requirements.txt"
+
+
+def get_missing_requirements_message():
+    """The warning message to display when a package is missing."""
+
+    extra = ""
+    if sys.flags.no_user_site:
+        extra = "-s "
+    return f"""
+Please install the updated requirements.txt file by running:
+{sys.executable} {extra}-m pip install -r {requirements_path}
+If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem.
+""".strip()
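A short sketch of how the new helper might be surfaced at startup. The try/except call site is an assumption for illustration (the diff only adds the helper itself), and av stands in for any dependency pinned in requirements.txt:

from utils.install_util import get_missing_requirements_message

try:
    import av  # placeholder for any package listed in requirements.txt
except ImportError:
    # Prints: <python> [-s] -m pip install -r <repo>/requirements.txt,
    # with -s included when the interpreter runs with user site-packages disabled.
    print(get_missing_requirements_message())
    raise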