diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 39d1992d7..69ce998eb 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -15,6 +15,14 @@ body: steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen. If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first. + - type: checkboxes + id: custom-nodes-test + attributes: + label: Custom Node Testing + description: Please confirm you have tried to reproduce the issue with all custom nodes disabled. + options: + - label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help) + required: true - type: textarea attributes: label: Expected Behavior diff --git a/.github/ISSUE_TEMPLATE/user-support.yml b/.github/ISSUE_TEMPLATE/user-support.yml index df28804c6..50657d493 100644 --- a/.github/ISSUE_TEMPLATE/user-support.yml +++ b/.github/ISSUE_TEMPLATE/user-support.yml @@ -11,6 +11,14 @@ body: **2:** You have made an effort to find public answers to your question before asking here. In other words, you googled it first, and scrolled through recent help topics. If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first. + - type: checkboxes + id: custom-nodes-test + attributes: + label: Custom Node Testing + description: Please confirm you have tried to reproduce the issue with all custom nodes disabled. + options: + - label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help) + required: true - type: textarea attributes: label: Your question diff --git a/README.md b/README.md index 47514d1b4..0de4a6bb5 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ [![Website][website-shield]][website-url] [![Dynamic JSON Badge][discord-shield]][discord-url] +[![Twitter][twitter-shield]][twitter-url] [![Matrix][matrix-shield]][matrix-url]
[![][github-release-shield]][github-release-link] @@ -20,6 +21,8 @@ [discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total [discord-url]: https://www.comfy.org/discord +[twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI +[twitter-url]: https://x.com/ComfyUI [github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver [github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases @@ -62,12 +65,13 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/) - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/) - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/) + - [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/) - Video Models - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/) - [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/) - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/) - - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) + - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/) - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/) - Audio Models - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/) @@ -95,7 +99,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/) - Latent previews with [TAESD](#how-to-show-high-quality-previews) - Starts up very fast. -- Works fully offline: will never download anything. +- Works fully offline: core will never download anything unless you want to. +- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview). - [Config file](extra_model_paths.yaml.example) to set the search paths for models. Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/) diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 000000000..12f18712f --- /dev/null +++ b/alembic.ini @@ -0,0 +1,84 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +# Use forward slashes (/) also on windows to provide an os agnostic path +script_location = alembic_db + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. 
+# If specified, requires python>=3.9 or the backports.zoneinfo library, and the tzdata library.
+# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
+# string value is passed to ZoneInfo()
+# leave blank for localtime
+# timezone =
+
+# max length of characters to apply to the "slug" field
+# truncate_slug_length = 40
+
+# set to 'true' to run the environment during
+# the 'revision' command, regardless of autogenerate
+# revision_environment = false
+
+# set to 'true' to allow .pyc and .pyo files without
+# a source .py file to be detected as revisions in the
+# versions/ directory
+# sourceless = false
+
+# version location specification; This defaults
+# to alembic_db/versions. When using multiple version
+# directories, initial revisions must be specified with --version-path.
+# The path separator used here should be the separator specified by "version_path_separator" below.
+# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions
+
+# version path separator; As mentioned above, this is the character used to split
+# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
+# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
+# Valid values for version_path_separator are:
+#
+# version_path_separator = :
+# version_path_separator = ;
+# version_path_separator = space
+# version_path_separator = newline
+#
+# Use os.pathsep. Default configuration used for new projects.
+version_path_separator = os
+
+# set to 'true' to search source files recursively
+# in each "version_locations" directory
+# new in Alembic version 1.10
+# recursive_version_locations = false
+
+# the output encoding used when revision files
+# are written from script.py.mako
+# output_encoding = utf-8
+
+sqlalchemy.url = sqlite:///user/comfyui.db
+
+
+[post_write_hooks]
+# post_write_hooks defines scripts or Python functions that are run
+# on newly generated revision scripts. See the documentation for further
+# detail and examples
+
+# format using "black" - use the console_scripts runner, against the "black" entrypoint
+# hooks = black
+# black.type = console_scripts
+# black.entrypoint = black
+# black.options = -l 79 REVISION_SCRIPT_FILENAME
+
+# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
+# hooks = ruff
+# ruff.type = exec
+# ruff.executable = %(here)s/.venv/bin/ruff
+# ruff.options = check --fix REVISION_SCRIPT_FILENAME
diff --git a/alembic_db/README.md b/alembic_db/README.md
new file mode 100644
index 000000000..3b808c7ca
--- /dev/null
+++ b/alembic_db/README.md
@@ -0,0 +1,4 @@
+## Generate new revision
+
+1. Update models in `/app/database/models.py`
+2. Run `alembic revision --autogenerate -m "{your message}"`
diff --git a/alembic_db/env.py b/alembic_db/env.py
new file mode 100644
index 000000000..4d7770679
--- /dev/null
+++ b/alembic_db/env.py
@@ -0,0 +1,64 @@
+from sqlalchemy import engine_from_config
+from sqlalchemy import pool
+
+from alembic import context
+
+# this is the Alembic Config object, which provides
+# access to the values within the .ini file in use.
+config = context.config
+
+
+from app.database.models import Base
+target_metadata = Base.metadata
+
+# other values from the config, defined by the needs of env.py,
+# can be acquired:
+# my_important_option = config.get_main_option("my_important_option")
+# ... etc.
+
+
+def run_migrations_offline() -> None:
+    """Run migrations in 'offline' mode.
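+
+    (Offline mode is what runs when an Alembic command is invoked with the
+    --sql flag; statements are emitted to the script output instead of being
+    executed against a live database.)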
+ This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + Calls to context.execute() here emit the given string to the + script output. + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + In this scenario we need to create an Engine + and associate a connection with the context. + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic_db/script.py.mako b/alembic_db/script.py.mako new file mode 100644 index 000000000..480b130d6 --- /dev/null +++ b/alembic_db/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/app/database/db.py b/app/database/db.py new file mode 100644 index 000000000..1de8b80ed --- /dev/null +++ b/app/database/db.py @@ -0,0 +1,112 @@ +import logging +import os +import shutil +from app.logger import log_startup_warning +from utils.install_util import get_missing_requirements_message +from comfy.cli_args import args + +_DB_AVAILABLE = False +Session = None + + +try: + from alembic import command + from alembic.config import Config + from alembic.runtime.migration import MigrationContext + from alembic.script import ScriptDirectory + from sqlalchemy import create_engine + from sqlalchemy.orm import sessionmaker + + _DB_AVAILABLE = True +except ImportError as e: + log_startup_warning( + f""" +------------------------------------------------------------------------ +Error importing dependencies: {e} +{get_missing_requirements_message()} +This error is happening because ComfyUI now uses a local sqlite database. 
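+Once the dependencies above are installed, restart ComfyUI; the database file
+(default: user/comfyui.db, configurable with --database-url) is then created automatically.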
+------------------------------------------------------------------------ +""".strip() + ) + + +def dependencies_available(): + """ + Temporary function to check if the dependencies are available + """ + return _DB_AVAILABLE + + +def can_create_session(): + """ + Temporary function to check if the database is available to create a session + During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created + """ + return dependencies_available() and Session is not None + + +def get_alembic_config(): + root_path = os.path.join(os.path.dirname(__file__), "../..") + config_path = os.path.abspath(os.path.join(root_path, "alembic.ini")) + scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db")) + + config = Config(config_path) + config.set_main_option("script_location", scripts_path) + config.set_main_option("sqlalchemy.url", args.database_url) + + return config + + +def get_db_path(): + url = args.database_url + if url.startswith("sqlite:///"): + return url.split("///")[1] + else: + raise ValueError(f"Unsupported database URL '{url}'.") + + +def init_db(): + db_url = args.database_url + logging.debug(f"Database URL: {db_url}") + db_path = get_db_path() + db_exists = os.path.exists(db_path) + + config = get_alembic_config() + + # Check if we need to upgrade + engine = create_engine(db_url) + conn = engine.connect() + + context = MigrationContext.configure(conn) + current_rev = context.get_current_revision() + + script = ScriptDirectory.from_config(config) + target_rev = script.get_current_head() + + if target_rev is None: + logging.warning("No target revision found.") + elif current_rev != target_rev: + # Backup the database pre upgrade + backup_path = db_path + ".bkp" + if db_exists: + shutil.copy(db_path, backup_path) + else: + backup_path = None + + try: + command.upgrade(config, target_rev) + logging.info(f"Database upgraded from {current_rev} to {target_rev}") + except Exception as e: + if backup_path: + # Restore the database from backup if upgrade fails + shutil.copy(backup_path, db_path) + os.remove(backup_path) + logging.exception("Error upgrading database: ") + raise e + + global Session + Session = sessionmaker(bind=engine) + + +def create_session(): + return Session() diff --git a/app/database/models.py b/app/database/models.py new file mode 100644 index 000000000..6facfb8f2 --- /dev/null +++ b/app/database/models.py @@ -0,0 +1,14 @@ +from sqlalchemy.orm import declarative_base + +Base = declarative_base() + + +def to_dict(obj): + fields = obj.__table__.columns.keys() + return { + field: (val.to_dict() if hasattr(val, "to_dict") else val) + for field in fields + if (val := getattr(obj, field)) + } + +# TODO: Define models here diff --git a/app/frontend_management.py b/app/frontend_management.py index d9ef8c921..001ebbecb 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -16,26 +16,17 @@ from importlib.metadata import version import requests from typing_extensions import NotRequired +from utils.install_util import get_missing_requirements_message, requirements_path + from comfy.cli_args import DEFAULT_VERSION_STRING import app.logger -# The path to the requirements.txt file -req_path = Path(__file__).parents[1] / "requirements.txt" - def frontend_install_warning_message(): - """The warning message to display when the frontend version is not up to date.""" - - extra = "" - if sys.flags.no_user_site: - extra = "-s " return f""" -Please install the updated requirements.txt file by 
running: -{sys.executable} {extra}-m pip install -r {req_path} +{get_missing_requirements_message()} This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead. - -If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem """.strip() @@ -48,7 +39,7 @@ def check_frontend_version(): try: frontend_version_str = version("comfyui-frontend-package") frontend_version = parse_version(frontend_version_str) - with open(req_path, "r", encoding="utf-8") as f: + with open(requirements_path, "r", encoding="utf-8") as f: required_frontend = parse_version(f.readline().split("=")[-1]) if frontend_version < required_frontend: app.logger.log_startup_warning( @@ -121,9 +112,22 @@ class FrontEndProvider: response.raise_for_status() # Raises an HTTPError if the response was an error return response.json() + @cached_property + def latest_prerelease(self) -> Release: + """Get the latest pre-release version - even if it's older than the latest release""" + release = [release for release in self.all_releases if release["prerelease"]] + + if not release: + raise ValueError("No pre-releases found") + + # GitHub returns releases in reverse chronological order, so first is latest + return release[0] + def get_release(self, version: str) -> Release: if version == "latest": return self.latest_release + elif version == "prerelease": + return self.latest_prerelease else: for release in self.all_releases: if release["tag_name"] in [version, f"v{version}"]: @@ -230,7 +234,7 @@ comfyui-workflow-templates is not installed. Raises: argparse.ArgumentTypeError: If the version string is invalid. """ - VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$" + VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+[-._a-zA-Z0-9]*|latest|prerelease)$" match_result = re.match(VERSION_PATTERN, value) if match_result is None: raise argparse.ArgumentTypeError(f"Invalid version string: {value}") diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 8db5e3e4d..1609ba308 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -203,6 +203,11 @@ parser.add_argument( help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)", ) +database_default_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") +) +parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. 
for an in-memory database you can use 'sqlite:///:memory:'.") + if comfy.options.args_parsing: args = parser.parse_args() else: diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 470eb9fdb..071b98332 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -37,6 +37,8 @@ class IO(StrEnum): CONTROL_NET = "CONTROL_NET" VAE = "VAE" MODEL = "MODEL" + LORA_MODEL = "LORA_MODEL" + LOSS_MAP = "LOSS_MAP" CLIP_VISION = "CLIP_VISION" CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT" STYLE_MODEL = "STYLE_MODEL" diff --git a/comfy/controlnet.py b/comfy/controlnet.py index d1ffeea36..c782c827f 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -433,8 +433,9 @@ class ControlLora(ControlNet): pass for k in self.control_weights: - if k not in {"lora_controlnet"}: - comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device())) + if (k not in {"lora_controlnet"}): + if (k.endswith(".up") or k.endswith(".down") or k.endswith(".weight") or k.endswith(".bias")) and ("__" not in k): + comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device())) def copy(self): c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index fbdf6f554..8030048fc 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1,4 +1,5 @@ import math +from functools import partial from scipy import integrate import torch @@ -142,6 +143,33 @@ class BrownianTreeNoiseSampler: return self.tree(t0, t1) / (t1 - t0).abs().sqrt() +def sigma_to_half_log_snr(sigma, model_sampling): + """Convert sigma to half-logSNR log(alpha_t / sigma_t).""" + if isinstance(model_sampling, comfy.model_sampling.CONST): + # log((1 - t) / t) = log((1 - sigma) / sigma) + return sigma.logit().neg() + return sigma.log().neg() + + +def half_log_snr_to_sigma(half_log_snr, model_sampling): + """Convert half-logSNR log(alpha_t / sigma_t) to sigma.""" + if isinstance(model_sampling, comfy.model_sampling.CONST): + # 1 / (1 + exp(half_log_snr)) + return half_log_snr.neg().sigmoid() + return half_log_snr.neg().exp() + + +def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4): + """Adjust the first sigma to avoid invalid logSNR.""" + if len(sigmas) <= 1: + return sigmas + if isinstance(model_sampling, comfy.model_sampling.CONST): + if sigmas[0] >= 1: + sigmas = sigmas.clone() + sigmas[0] = model_sampling.percent_to_sigma(percent_offset) + return sigmas + + @torch.no_grad() def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """Implements Algorithm 2 (Euler steps) from Karras et al. 
(2022).""" @@ -753,6 +781,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No old_denoised = denoised return x + @torch.no_grad() def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): """DPM-Solver++(2M) SDE.""" @@ -768,9 +797,12 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + old_denoised = None - h_last = None - h = None + h, h_last = None, None for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -781,26 +813,29 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl x = denoised else: # DPM-Solver++(2M) SDE - t, s = -sigmas[i].log(), -sigmas[i + 1].log() - h = s - t - eta_h = eta * h + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s + h_eta = h * (eta + 1) - x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised + alpha_t = sigmas[i + 1] * lambda_t.exp() + + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised if old_denoised is not None: r = h_last / h if solver_type == 'heun': - x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised) + x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised) elif solver_type == 'midpoint': - x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised) + x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised) - if eta: - x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise + if eta > 0 and s_noise > 0: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise old_denoised = denoised h_last = h return x + @torch.no_grad() def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """DPM-Solver++(3M) SDE.""" @@ -814,6 +849,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + denoised_1, denoised_2 = None, None h, h_1, h_2 = None, None, None @@ -825,13 +864,16 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl # Denoising step x = denoised else: - t, s = -sigmas[i].log(), -sigmas[i + 1].log() - h = s - t + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s h_eta = h * (eta + 1) - x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised + alpha_t = sigmas[i + 1] * lambda_t.exp() + + x = sigmas[i + 1] / 
sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised if h_2 is not None: + # DPM-Solver++(3M) SDE r0 = h_1 / h r1 = h_2 / h d1_0 = (denoised - denoised_1) / r0 @@ -840,20 +882,22 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl d2 = (d1_0 - d1_1) / (r0 + r1) phi_2 = h_eta.neg().expm1() / h_eta + 1 phi_3 = phi_2 / h_eta - 0.5 - x = x + phi_2 * d1 - phi_3 * d2 + x = x + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) * d2 elif h_1 is not None: + # DPM-Solver++(2M) SDE r = h_1 / h d = (denoised - denoised_1) / r phi_2 = h_eta.neg().expm1() / h_eta + 1 - x = x + phi_2 * d + x = x + (alpha_t * phi_2) * d - if eta: + if eta > 0 and s_noise > 0: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise denoised_1, denoised_2 = denoised, denoised_1 h_1, h_2 = h, h_1 return x + @torch.no_grad() def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): if len(sigmas) <= 1: @@ -863,6 +907,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) + @torch.no_grad() def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): if len(sigmas) <= 1: @@ -872,6 +917,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) + @torch.no_grad() def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): if len(sigmas) <= 1: @@ -1449,12 +1495,12 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None old_denoised = denoised return x + @torch.no_grad() def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5): - ''' - SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2 - Arxiv: https://arxiv.org/abs/2305.14267 - ''' + """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2. 
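+    Steps are taken in half-logSNR space, lambda = log(alpha_t / sigma_t): for CONST model sampling this is
+    lambda = log((1 - sigma) / sigma) with inverse sigma = 1 / (1 + exp(lambda)); otherwise lambda = -log(sigma)
+    (see sigma_to_half_log_snr and half_log_snr_to_sigma above).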
+ arXiv: https://arxiv.org/abs/2305.14267 + """ extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler @@ -1462,6 +1508,11 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non inject_noise = eta > 0 and s_noise > 0 + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: @@ -1469,80 +1520,96 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non if sigmas[i + 1] == 0: x = denoised else: - t, t_next = -sigmas[i].log(), -sigmas[i + 1].log() - h = t_next - t + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s h_eta = h * (eta + 1) - s = t + r * h + lambda_s_1 = lambda_s + r * h fac = 1 / (2 * r) - sigma_s = s.neg().exp() + sigma_s_1 = sigma_fn(lambda_s_1) + + # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t) + alpha_s_1 = sigma_s_1 * lambda_s_1.exp() + alpha_t = sigmas[i + 1] * lambda_t.exp() coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1() if inject_noise: + # 0 < r < 1 noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt() - noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt() - noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1]) + noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt() + noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1]) # Step 1 - x_2 = (coeff_1 + 1) * x - coeff_1 * denoised - if inject_noise: - x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise - denoised_2 = model(x_2, sigma_s * s_in, **extra_args) - - # Step 2 - denoised_d = (1 - fac) * denoised + fac * denoised_2 - x = (coeff_2 + 1) * x - coeff_2 * denoised_d - if inject_noise: - x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise - return x - -@torch.no_grad() -def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3): - ''' - SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3 - Arxiv: https://arxiv.org/abs/2305.14267 - ''' - extra_args = {} if extra_args is None else extra_args - seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler - s_in = x.new_ones([x.shape[0]]) - - inject_noise = eta > 0 and s_noise > 0 - - for i in trange(len(sigmas) - 1, disable=disable): - denoised = model(x, sigmas[i] * s_in, **extra_args) - if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) - if sigmas[i + 1] == 0: - x = denoised - else: - t, t_next = -sigmas[i].log(), -sigmas[i + 1].log() - h = t_next - t - h_eta = h * (eta + 1) - s_1 = t + r_1 * h - s_2 = t + r_2 * h - sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp() - - coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1() - if inject_noise: - noise_coeff_1 = (-2 
* r_1 * h * eta).expm1().neg().sqrt() - noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt() - noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt() - noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1]) - - # Step 1 - x_2 = (coeff_1 + 1) * x - coeff_1 * denoised + x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised if inject_noise: x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) # Step 2 - x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised) + denoised_d = (1 - fac) * denoised + fac * denoised_2 + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d + if inject_noise: + x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise + return x + + +@torch.no_grad() +def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3): + """SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3. + arXiv: https://arxiv.org/abs/2305.14267 + """ + extra_args = {} if extra_args is None else extra_args + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + + inject_noise = eta > 0 and s_noise > 0 + + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + x = denoised + else: + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s + h_eta = h * (eta + 1) + lambda_s_1 = lambda_s + r_1 * h + lambda_s_2 = lambda_s + r_2 * h + sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2) + + # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t) + alpha_s_1 = sigma_s_1 * lambda_s_1.exp() + alpha_s_2 = sigma_s_2 * lambda_s_2.exp() + alpha_t = sigmas[i + 1] * lambda_t.exp() + + coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1() + if inject_noise: + # 0 < r_1 < r_2 < 1 + noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt() + noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt() + noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt() + noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1]) + + # Step 1 + x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised + if inject_noise: + x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise + denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) + + # Step 2 + x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * 
(coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised) if inject_noise: x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args) # Step 3 - x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised) + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised) if inject_noise: x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise return x diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py index a12f892d2..5c4356a3f 100644 --- a/comfy/ldm/cosmos/blocks.py +++ b/comfy/ldm/cosmos/blocks.py @@ -26,16 +26,6 @@ from torch import nn from comfy.ldm.modules.attention import optimized_attention -def apply_rotary_pos_emb( - t: torch.Tensor, - freqs: torch.Tensor, -) -> torch.Tensor: - t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float() - t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1] - t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t) - return t_out - - def get_normalization(name: str, channels: int, weight_args={}, operations=None): if name == "I": return nn.Identity() diff --git a/comfy/ldm/cosmos/position_embedding.py b/comfy/ldm/cosmos/position_embedding.py index 4d6a58dba..c925811d4 100644 --- a/comfy/ldm/cosmos/position_embedding.py +++ b/comfy/ldm/cosmos/position_embedding.py @@ -66,15 +66,16 @@ class VideoRopePosition3DEmb(VideoPositionEmb): h_extrapolation_ratio: float = 1.0, w_extrapolation_ratio: float = 1.0, t_extrapolation_ratio: float = 1.0, + enable_fps_modulation: bool = True, device=None, **kwargs, # used for compatibility with other positional embeddings; unused in this class ): del kwargs super().__init__() - self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device)) self.base_fps = base_fps self.max_h = len_h self.max_w = len_w + self.enable_fps_modulation = enable_fps_modulation dim = head_dim dim_h = dim // 6 * 2 @@ -132,21 +133,19 @@ class VideoRopePosition3DEmb(VideoPositionEmb): temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device)) B, T, H, W, _ = B_T_H_W_C + seq = torch.arange(max(H, W, T), dtype=torch.float, device=device) uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max()) assert ( uniform_fps or B == 1 or T == 1 ), "For video batch, batch size should be 1 for non-uniform fps. 
For image batch, T should be 1" - assert ( - H <= self.max_h and W <= self.max_w - ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})" - half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs) - half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs) + half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs) + half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs) # apply sequence scaling in temporal dimension - if fps is None: # image case - half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs) + if fps is None or self.enable_fps_modulation is False: # image case + half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs) else: - half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs) + half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs) half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1) half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1) diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py new file mode 100644 index 000000000..316117f77 --- /dev/null +++ b/comfy/ldm/cosmos/predict2.py @@ -0,0 +1,864 @@ +# original code from: https://github.com/nvidia-cosmos/cosmos-predict2 + +import torch +from torch import nn +from einops import rearrange +from einops.layers.torch import Rearrange +import logging +from typing import Callable, Optional, Tuple +import math + +from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis +from torchvision import transforms + +from comfy.ldm.modules.attention import optimized_attention + +def apply_rotary_pos_emb( + t: torch.Tensor, + freqs: torch.Tensor, +) -> torch.Tensor: + t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float() + t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1] + t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t) + return t_out + + +# ---------------------- Feed Forward Network ----------------------- +class GPT2FeedForward(nn.Module): + def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None: + super().__init__() + self.activation = nn.GELU() + self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype) + self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype) + + self._layer_id = None + self._dim = d_model + self._hidden_dim = d_ff + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.layer1(x) + + x = self.activation(x) + x = self.layer2(x) + return x + + +def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor: + """Computes multi-head attention using PyTorch's native implementation. + + This function provides a PyTorch backend alternative to Transformer Engine's attention operation. + It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product + attention, and rearranges the output back to the original format. 
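+
+    A quick shape check (illustrative; names follow the dimension conventions below):
+
+        >>> q = k = v = torch.randn(2, 16, 8, 64)  # (B, S, H, D)
+        >>> torch_attention_op(q, k, v).shape
+        torch.Size([2, 16, 512])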
+ + The input tensor names use the following dimension conventions: + + - B: batch size + - S: sequence length + - H: number of attention heads + - D: head dimension + + Args: + q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim) + k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim) + v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim) + + Returns: + Attention output tensor with shape (batch, seq_len, n_heads * head_dim) + """ + in_q_shape = q_B_S_H_D.shape + in_k_shape = k_B_S_H_D.shape + q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1]) + k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) + v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) + return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True) + + +class Attention(nn.Module): + """ + A flexible attention module supporting both self-attention and cross-attention mechanisms. + + This module implements a multi-head attention layer that can operate in either self-attention + or cross-attention mode. The mode is determined by whether a context dimension is provided. + The implementation uses scaled dot-product attention and supports optional bias terms and + dropout regularization. + + Args: + query_dim (int): The dimensionality of the query vectors. + context_dim (int, optional): The dimensionality of the context (key/value) vectors. + If None, the module operates in self-attention mode using query_dim. Default: None + n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8 + head_dim (int, optional): The dimension of each attention head. Default: 64 + dropout (float, optional): Dropout probability applied to the output. Default: 0.0 + qkv_format (str, optional): Format specification for QKV tensors. Default: "bshd" + backend (str, optional): Backend to use for the attention operation. Default: "transformer_engine" + + Examples: + >>> # Self-attention with 512 dimensions and 8 heads + >>> self_attn = Attention(query_dim=512) + >>> x = torch.randn(32, 16, 512) # (batch_size, seq_len, dim) + >>> out = self_attn(x) # (32, 16, 512) + + >>> # Cross-attention + >>> cross_attn = Attention(query_dim=512, context_dim=256) + >>> query = torch.randn(32, 16, 512) + >>> context = torch.randn(32, 8, 256) + >>> out = cross_attn(query, context) # (32, 16, 512) + """ + + def __init__( + self, + query_dim: int, + context_dim: Optional[int] = None, + n_heads: int = 8, + head_dim: int = 64, + dropout: float = 0.0, + device=None, + dtype=None, + operations=None, + ) -> None: + super().__init__() + logging.debug( + f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{n_heads} heads with a dimension of {head_dim}." 
+ ) + self.is_selfattn = context_dim is None # self attention + + context_dim = query_dim if context_dim is None else context_dim + inner_dim = head_dim * n_heads + + self.n_heads = n_heads + self.head_dim = head_dim + self.query_dim = query_dim + self.context_dim = context_dim + + self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.v_norm = nn.Identity() + + self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype) + self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity() + + self.attn_op = torch_attention_op + + self._query_dim = query_dim + self._context_dim = context_dim + self._inner_dim = inner_dim + + def compute_qkv( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + rope_emb: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q = self.q_proj(x) + context = x if context is None else context + k = self.k_proj(context) + v = self.v_proj(context) + q, k, v = map( + lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim), + (q, k, v), + ) + + def apply_norm_and_rotary_pos_emb( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q = self.q_norm(q) + k = self.k_norm(k) + v = self.v_norm(v) + if self.is_selfattn and rope_emb is not None: # only apply to self-attention! 
q = apply_rotary_pos_emb(q, rope_emb)
+                k = apply_rotary_pos_emb(k, rope_emb)
+            return q, k, v
+
+        q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)
+
+        return q, k, v
+
+    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+        result = self.attn_op(q, k, v)  # [B, S, H, D]
+        return self.output_dropout(self.output_proj(result))
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        context: Optional[torch.Tensor] = None,
+        rope_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Args:
+            x (Tensor): The query tensor of shape [B, Mq, K]
+            context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
+        """
+        q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
+        return self.compute_attention(q, k, v)
+
+
+class Timesteps(nn.Module):
+    def __init__(self, num_channels: int):
+        super().__init__()
+        self.num_channels = num_channels
+
+    def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor:
+        assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}"
+        timesteps = timesteps_B_T.flatten().float()
+        half_dim = self.num_channels // 2
+        exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
+        exponent = exponent / (half_dim - 0.0)
+
+        emb = torch.exp(exponent)
+        emb = timesteps[:, None].float() * emb[None, :]
+
+        sin_emb = torch.sin(emb)
+        cos_emb = torch.cos(emb)
+        emb = torch.cat([cos_emb, sin_emb], dim=-1)
+
+        return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1])
+
+
+class TimestepEmbedding(nn.Module):
+    def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):
+        super().__init__()
+        logging.debug(
+            f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
+        )
+        self.in_dim = in_features
+        self.out_dim = out_features
+        self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype)
+        self.activation = nn.SiLU()
+        self.use_adaln_lora = use_adaln_lora
+        if use_adaln_lora:
+            self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype)
+        else:
+            self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype)
+
+    def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        emb = self.linear_1(sample)
+        emb = self.activation(emb)
+        emb = self.linear_2(emb)
+
+        if self.use_adaln_lora:
+            adaln_lora_B_T_3D = emb
+            emb_B_T_D = sample
+        else:
+            adaln_lora_B_T_3D = None
+            emb_B_T_D = emb
+
+        return emb_B_T_D, adaln_lora_B_T_3D
+
+
+class PatchEmbed(nn.Module):
+    """
+    PatchEmbed is a module for embedding patches from an input tensor by rearranging it into patches and
+    applying a linear projection. This module can process inputs with temporal (video) and spatial (image) dimensions,
+    making it suitable for video and image processing tasks. It supports dividing the input into patches
+    and embedding each patch into a vector of size `out_channels`.
+
+    Parameters:
+    - spatial_patch_size (int): The size of each spatial patch.
+    - temporal_patch_size (int): The size of each temporal patch.
+    - in_channels (int): Number of input channels. Default: 3.
+    - out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
+ - bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True. + """ + + def __init__( + self, + spatial_patch_size: int, + temporal_patch_size: int, + in_channels: int = 3, + out_channels: int = 768, + device=None, dtype=None, operations=None + ): + super().__init__() + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + + self.proj = nn.Sequential( + Rearrange( + "b c (t r) (h m) (w n) -> b t h w (c r m n)", + r=temporal_patch_size, + m=spatial_patch_size, + n=spatial_patch_size, + ), + operations.Linear( + in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype + ), + ) + self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the PatchEmbed module. + + Parameters: + - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where + B is the batch size, + C is the number of channels, + T is the temporal dimension, + H is the height, and + W is the width of the input. + + Returns: + - torch.Tensor: The embedded patches as a tensor, with shape b t h w c. + """ + assert x.dim() == 5 + _, _, T, H, W = x.shape + assert ( + H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0 + ), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}" + assert T % self.temporal_patch_size == 0 + x = self.proj(x) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of video DiT. + """ + + def __init__( + self, + hidden_size: int, + spatial_patch_size: int, + temporal_patch_size: int, + out_channels: int, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + device=None, dtype=None, operations=None + ): + super().__init__() + self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = operations.Linear( + hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype + ) + self.hidden_size = hidden_size + self.n_adaln_chunks = 2 + self.use_adaln_lora = use_adaln_lora + self.adaln_lora_dim = adaln_lora_dim + if use_adaln_lora: + self.adaln_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype), + ) + else: + self.adaln_modulation = nn.Sequential( + nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype) + ) + + def forward( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_T_D: torch.Tensor, + adaln_lora_B_T_3D: Optional[torch.Tensor] = None, + ): + if self.use_adaln_lora: + assert adaln_lora_B_T_3D is not None + shift_B_T_D, scale_B_T_D = ( + self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size] + ).chunk(2, dim=-1) + else: + shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1) + + shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d"), rearrange( + scale_B_T_D, "b t d -> b t 1 1 d" + ) + + def _fn( + _x_B_T_H_W_D: torch.Tensor, + _norm_layer: nn.Module, + _scale_B_T_1_1_D: torch.Tensor, + _shift_B_T_1_1_D: torch.Tensor, + ) -> torch.Tensor: + return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D + + x_B_T_H_W_D = _fn(x_B_T_H_W_D, 
self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D) + x_B_T_H_W_O = self.linear(x_B_T_H_W_D) + return x_B_T_H_W_O + + +class Block(nn.Module): + """ + A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation. + Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation. + + Parameters: + x_dim (int): Dimension of input features + context_dim (int): Dimension of context features for cross-attention + num_heads (int): Number of attention heads + mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0 + use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False + adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256 + + The block applies the following sequence: + 1. Self-attention with AdaLN modulation + 2. Cross-attention with AdaLN modulation + 3. MLP with AdaLN modulation + + Each component uses skip connections and layer normalization. + """ + + def __init__( + self, + x_dim: int, + context_dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + self.x_dim = x_dim + self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations) + + self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.cross_attn = Attention( + x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations + ) + + self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations) + + self.use_adaln_lora = use_adaln_lora + if self.use_adaln_lora: + self.adaln_modulation_self_attn = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), + ) + self.adaln_modulation_cross_attn = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), + ) + self.adaln_modulation_mlp = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), + ) + else: + self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) + self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) + self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) + + def forward( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_T_D: torch.Tensor, + crossattn_emb: torch.Tensor, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_T_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if extra_per_block_pos_emb is not None: + x_B_T_H_W_D = 
x_B_T_H_W_D + extra_per_block_pos_emb + + if self.use_adaln_lora: + shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = ( + self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = ( + self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = ( + self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + else: + shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn( + emb_B_T_D + ).chunk(3, dim=-1) + shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn( + emb_B_T_D + ).chunk(3, dim=-1) + shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1) + + # Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting + shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d") + scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d") + gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d") + + shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d") + scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d") + gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d") + + shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d") + scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d") + gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d") + + B, T, H, W, D = x_B_T_H_W_D.shape + + def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D): + return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D + + normalized_x_B_T_H_W_D = _fn( + x_B_T_H_W_D, + self.layer_norm_self_attn, + scale_self_attn_B_T_1_1_D, + shift_self_attn_B_T_1_1_D, + ) + result_B_T_H_W_D = rearrange( + self.self_attn( + # normalized_x_B_T_HW_D, + rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), + None, + rope_emb=rope_emb_L_1_1_D, + ), + "b (t h w) d -> b t h w d", + t=T, + h=H, + w=W, + ) + x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D + + def _x_fn( + _x_B_T_H_W_D: torch.Tensor, + layer_norm_cross_attn: Callable, + _scale_cross_attn_B_T_1_1_D: torch.Tensor, + _shift_cross_attn_B_T_1_1_D: torch.Tensor, + ) -> torch.Tensor: + _normalized_x_B_T_H_W_D = _fn( + _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D + ) + _result_B_T_H_W_D = rearrange( + self.cross_attn( + rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), + crossattn_emb, + rope_emb=rope_emb_L_1_1_D, + ), + "b (t h w) d -> b t h w d", + t=T, + h=H, + w=W, + ) + return _result_B_T_H_W_D + + result_B_T_H_W_D = _x_fn( + x_B_T_H_W_D, + self.layer_norm_cross_attn, + scale_cross_attn_B_T_1_1_D, + shift_cross_attn_B_T_1_1_D, + ) + x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D + + normalized_x_B_T_H_W_D = _fn( + x_B_T_H_W_D, + self.layer_norm_mlp, + scale_mlp_B_T_1_1_D, + shift_mlp_B_T_1_1_D, + ) + result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D) + x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D + return x_B_T_H_W_D + + +class MiniTrainDIT(nn.Module): + """ + A clean impl of DIT that can load and reproduce the training results of the original DIT 
model in Cosmos 1.
+    A general implementation of an AdaLN-modulated ViT-like (DiT) transformer for video processing.
+
+    Args:
+        max_img_h (int): Maximum height of the input images.
+        max_img_w (int): Maximum width of the input images.
+        max_frames (int): Maximum number of frames in the video sequence.
+        in_channels (int): Number of input channels (e.g., RGB channels for color images).
+        out_channels (int): Number of output channels.
+        patch_spatial (int): Spatial resolution of patches for input processing.
+        patch_temporal (int): Temporal resolution of patches for input processing.
+        concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
+        model_channels (int): Base number of channels used throughout the model.
+        num_blocks (int): Number of transformer blocks.
+        num_heads (int): Number of heads in the multi-head attention layers.
+        mlp_ratio (float): Expansion ratio for MLP blocks.
+        crossattn_emb_channels (int): Number of embedding channels for cross-attention.
+        pos_emb_cls (str): Type of positional embeddings.
+        pos_emb_learnable (bool): Whether positional embeddings are learnable.
+        pos_emb_interpolation (str): Method for interpolating positional embeddings.
+        min_fps (int): Minimum frames per second.
+        max_fps (int): Maximum frames per second.
+        use_adaln_lora (bool): Whether to use AdaLN-LoRA.
+        adaln_lora_dim (int): Dimension for AdaLN-LoRA.
+        rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
+        rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
+        rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
+        extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
+        extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
+        extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
+        extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
+ """ + + def __init__( + self, + max_img_h: int, + max_img_w: int, + max_frames: int, + in_channels: int, + out_channels: int, + patch_spatial: int, # tuple, + patch_temporal: int, + concat_padding_mask: bool = True, + # attention settings + model_channels: int = 768, + num_blocks: int = 10, + num_heads: int = 16, + mlp_ratio: float = 4.0, + # cross attention settings + crossattn_emb_channels: int = 1024, + # positional embedding settings + pos_emb_cls: str = "sincos", + pos_emb_learnable: bool = False, + pos_emb_interpolation: str = "crop", + min_fps: int = 1, + max_fps: int = 30, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + rope_h_extrapolation_ratio: float = 1.0, + rope_w_extrapolation_ratio: float = 1.0, + rope_t_extrapolation_ratio: float = 1.0, + extra_per_block_abs_pos_emb: bool = False, + extra_h_extrapolation_ratio: float = 1.0, + extra_w_extrapolation_ratio: float = 1.0, + extra_t_extrapolation_ratio: float = 1.0, + rope_enable_fps_modulation: bool = True, + image_model=None, + device=None, + dtype=None, + operations=None, + ) -> None: + super().__init__() + self.dtype = dtype + self.max_img_h = max_img_h + self.max_img_w = max_img_w + self.max_frames = max_frames + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + self.num_heads = num_heads + self.num_blocks = num_blocks + self.model_channels = model_channels + self.concat_padding_mask = concat_padding_mask + # positional embedding settings + self.pos_emb_cls = pos_emb_cls + self.pos_emb_learnable = pos_emb_learnable + self.pos_emb_interpolation = pos_emb_interpolation + self.min_fps = min_fps + self.max_fps = max_fps + self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio + self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio + self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio + self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb + self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio + self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio + self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio + self.rope_enable_fps_modulation = rope_enable_fps_modulation + + self.build_pos_embed(device=device, dtype=dtype) + self.use_adaln_lora = use_adaln_lora + self.adaln_lora_dim = adaln_lora_dim + self.t_embedder = nn.Sequential( + Timesteps(model_channels), + TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,), + ) + + in_channels = in_channels + 1 if concat_padding_mask else in_channels + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + device=device, dtype=dtype, operations=operations, + ) + + self.blocks = nn.ModuleList( + [ + Block( + x_dim=model_channels, + context_dim=crossattn_emb_channels, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + device=device, dtype=dtype, operations=operations, + ) + for _ in range(num_blocks) + ] + ) + + self.final_layer = FinalLayer( + hidden_size=self.model_channels, + spatial_patch_size=self.patch_spatial, + temporal_patch_size=self.patch_temporal, + out_channels=self.out_channels, + use_adaln_lora=self.use_adaln_lora, + adaln_lora_dim=self.adaln_lora_dim, + device=device, dtype=dtype, operations=operations, + ) + + self.t_embedding_norm = 
operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype)
+
+    def build_pos_embed(self, device=None, dtype=None) -> None:
+        if self.pos_emb_cls == "rope3d":
+            cls_type = VideoRopePosition3DEmb
+        else:
+            raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
+
+        logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
+        kwargs = dict(
+            model_channels=self.model_channels,
+            len_h=self.max_img_h // self.patch_spatial,
+            len_w=self.max_img_w // self.patch_spatial,
+            len_t=self.max_frames // self.patch_temporal,
+            max_fps=self.max_fps,
+            min_fps=self.min_fps,
+            is_learnable=self.pos_emb_learnable,
+            interpolation=self.pos_emb_interpolation,
+            head_dim=self.model_channels // self.num_heads,
+            h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
+            w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
+            t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
+            enable_fps_modulation=self.rope_enable_fps_modulation,
+            device=device,
+        )
+        self.pos_embedder = cls_type(
+            **kwargs,  # type: ignore
+        )
+
+        if self.extra_per_block_abs_pos_emb:
+            kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
+            kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
+            kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
+            kwargs["device"] = device
+            kwargs["dtype"] = dtype
+            self.extra_pos_embedder = LearnablePosEmbAxis(
+                **kwargs,  # type: ignore
+            )
+
+    def prepare_embedded_sequence(
+        self,
+        x_B_C_T_H_W: torch.Tensor,
+        fps: Optional[torch.Tensor] = None,
+        padding_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+        """
+        Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.
+
+        Args:
+            x_B_C_T_H_W (torch.Tensor): video
+            fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
+                If None, a default value (`self.base_fps`) will be used.
+            padding_mask (Optional[torch.Tensor]): currently not used
+
+        Returns:
+            Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+                - A tensor of shape (B, T, H, W, D) with the embedded sequence.
+                - An optional positional embedding tensor, returned only if the positional embedding class
+                    (`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
+                - An optional extra per-block absolute positional embedding tensor, or None when
+                    `self.extra_per_block_abs_pos_emb` is disabled.
+
+        Notes:
+            - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
+            - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
+            - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
+                the `self.pos_embedder` with the shape [T, H, W].
+            - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
+                `self.pos_embedder` with the fps tensor.
+            - Otherwise, the positional embeddings are generated without considering fps.
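+
+        Example (illustrative only; shapes assume patch_spatial=2, patch_temporal=1 and the "rope3d" embeddings used here):
+            >>> x = torch.randn(1, 16, 8, 64, 64)  # (B, C, T, H, W) latent video
+            >>> tokens, rope_emb, extra = model.prepare_embedded_sequence(x, fps=torch.tensor([24]))
+            >>> tokens.shape  # torch.Size([1, 8, 32, 32, model_channels])
+            >>> # rope_emb holds the RoPE embeddings; extra is None unless extra_per_block_abs_pos_emb is set.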
+ """ + if self.concat_padding_mask: + if padding_mask is None: + padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device) + else: + padding_mask = transforms.functional.resize( + padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 + ) + x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) + + if self.extra_per_block_abs_pos_emb: + extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype) + else: + extra_pos_emb = None + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D] + + return x_B_T_H_W_D, None, extra_pos_emb + + def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor: + x_B_C_Tt_Hp_Wp = rearrange( + x_B_T_H_W_M, + "B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)", + p1=self.patch_spatial, + p2=self.patch_spatial, + t=self.patch_temporal, + ) + return x_B_C_Tt_Hp_Wp + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + x_B_C_T_H_W = x + timesteps_B_T = timesteps + crossattn_emb = context + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + """ + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x_B_C_T_H_W, + fps=fps, + padding_mask=padding_mask, + ) + + if timesteps_B_T.ndim == 1: + timesteps_B_T = timesteps_B_T.unsqueeze(1) + t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype)) + t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D) + + # for logging purpose + affline_scale_log_info = {} + affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach() + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = t_embedding_B_T_D + self.crossattn_emb = crossattn_emb + + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + assert ( + x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape + ), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}" + + block_kwargs = { + "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0), + "adaln_lora_B_T_3D": adaln_lora_B_T_3D, + "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + for block in self.blocks: + x_B_T_H_W_D = block( + x_B_T_H_W_D, + t_embedding_B_T_D, + crossattn_emb, + **block_kwargs, + ) + + x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D) + x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O) + return x_B_C_Tt_Hp_Wp diff --git a/comfy/ldm/flux/controlnet.py b/comfy/ldm/flux/controlnet.py index 5322c4891..dbd2a47c0 100644 --- a/comfy/ldm/flux/controlnet.py +++ b/comfy/ldm/flux/controlnet.py @@ -121,6 +121,9 @@ class ControlNetFlux(Flux): if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") + if y is None: + y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, 
dtype=img.dtype) + # running on sequences img img = self.img_in(img) @@ -174,7 +177,7 @@ class ControlNetFlux(Flux): out["output"] = out_output[:self.main_model_single] return out - def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs): + def forward(self, x, timesteps, context, y=None, guidance=None, hint=None, **kwargs): patch_size = 2 if self.latent_input: hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size)) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index ef4ba4106..846703d52 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -101,6 +101,10 @@ class Flux(nn.Module): transformer_options={}, attn_mask: Tensor = None, ) -> Tensor: + + if y is None: + y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) + patches_replace = transformer_options.get("patches_replace", {}) if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -155,6 +159,9 @@ class Flux(nn.Module): if add is not None: img += add + if img.dtype == torch.float16: + img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504) + img = torch.cat((txt, img), 1) for i, block in enumerate(self.single_blocks): @@ -188,7 +195,7 @@ class Flux(nn.Module): img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img - def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs): + def forward(self, x, timestep, context, y=None, guidance=None, control=None, transformer_options={}, **kwargs): bs, c, h, w = x.shape patch_size = self.patch_size x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 056e101a4..ad9a7daea 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -261,8 +261,8 @@ class CrossAttention(nn.Module): self.heads = heads self.dim_head = dim_head - self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device) - self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device) + self.q_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device) + self.k_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device) self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device) self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 2cb77d85d..35d2270ee 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -753,7 +753,7 @@ class BasicTransformerBlock(nn.Module): for p in patch: n = p(n, extra_options) - x += n + x = n + x if "middle_patch" in transformer_patches: patch = transformer_patches["middle_patch"] for p in patch: @@ -793,12 +793,12 @@ class BasicTransformerBlock(nn.Module): for p in patch: n = p(n, extra_options) - x += n + x = n + x if self.is_res: x_skip = x x = self.ff(self.norm3(x)) if self.is_res: - x += x_skip + x = x_skip + x return x diff --git a/comfy/model_base.py b/comfy/model_base.py index 638b04092..cb7689e84 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -34,6 +34,7 @@ import comfy.ldm.flux.model import comfy.ldm.lightricks.model import comfy.ldm.hunyuan_video.model import comfy.ldm.cosmos.model +import comfy.ldm.cosmos.predict2 import comfy.ldm.lumina.model import comfy.ldm.wan.model import 
comfy.ldm.hunyuan3d.model @@ -48,6 +49,7 @@ import comfy.ops from enum import Enum from . import utils import comfy.latent_formats +import comfy.model_sampling import math from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -63,38 +65,39 @@ class ModelType(Enum): V_PREDICTION_CONTINUOUS = 7 FLUX = 8 IMG_TO_IMG = 9 - - -from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV + FLOW_COSMOS = 10 def model_sampling(model_config, model_type): - s = ModelSamplingDiscrete + s = comfy.model_sampling.ModelSamplingDiscrete if model_type == ModelType.EPS: - c = EPS + c = comfy.model_sampling.EPS elif model_type == ModelType.V_PREDICTION: - c = V_PREDICTION + c = comfy.model_sampling.V_PREDICTION elif model_type == ModelType.V_PREDICTION_EDM: - c = V_PREDICTION - s = ModelSamplingContinuousEDM + c = comfy.model_sampling.V_PREDICTION + s = comfy.model_sampling.ModelSamplingContinuousEDM elif model_type == ModelType.FLOW: c = comfy.model_sampling.CONST s = comfy.model_sampling.ModelSamplingDiscreteFlow elif model_type == ModelType.STABLE_CASCADE: - c = EPS - s = StableCascadeSampling + c = comfy.model_sampling.EPS + s = comfy.model_sampling.StableCascadeSampling elif model_type == ModelType.EDM: - c = EDM - s = ModelSamplingContinuousEDM + c = comfy.model_sampling.EDM + s = comfy.model_sampling.ModelSamplingContinuousEDM elif model_type == ModelType.V_PREDICTION_CONTINUOUS: - c = V_PREDICTION - s = ModelSamplingContinuousV + c = comfy.model_sampling.V_PREDICTION + s = comfy.model_sampling.ModelSamplingContinuousV elif model_type == ModelType.FLUX: c = comfy.model_sampling.CONST s = comfy.model_sampling.ModelSamplingFlux elif model_type == ModelType.IMG_TO_IMG: c = comfy.model_sampling.IMG_TO_IMG + elif model_type == ModelType.FLOW_COSMOS: + c = comfy.model_sampling.COSMOS_RFLOW + s = comfy.model_sampling.ModelSamplingCosmosRFlow class ModelSampling(s, c): pass @@ -102,6 +105,13 @@ def model_sampling(model_config, model_type): return ModelSampling(model_config) +def convert_tensor(extra, dtype): + if hasattr(extra, "dtype"): + if extra.dtype != torch.int and extra.dtype != torch.long: + extra = extra.to(dtype) + return extra + + class BaseModel(torch.nn.Module): def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel): super().__init__() @@ -165,13 +175,13 @@ class BaseModel(torch.nn.Module): extra_conds = {} for o in kwargs: extra = kwargs[o] + if hasattr(extra, "dtype"): - if extra.dtype != torch.int and extra.dtype != torch.long: - extra = extra.to(dtype) - if isinstance(extra, list): + extra = convert_tensor(extra, dtype) + elif isinstance(extra, list): ex = [] for ext in extra: - ex.append(ext.to(dtype)) + ex.append(convert_tensor(ext, dtype)) extra = ex extra_conds[o] = extra @@ -991,6 +1001,43 @@ class CosmosVideo(BaseModel): latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image) return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5) +class CosmosPredict2(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW_COSMOS, image_to_video=False, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.predict2.MiniTrainDIT) + self.image_to_video = image_to_video + if self.image_to_video: + self.concat_keys = ("mask_inverted",) + + def extra_conds(self, **kwargs): + out = 
super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if denoise_mask is not None: + out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask) + + out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None)) + return out + + def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): + if denoise_mask is None: + return timestep + condition_video_mask_B_1_T_1_1 = denoise_mask.mean(dim=[1, 3, 4], keepdim=True) + c_noise_B_1_T_1_1 = 0.0 * (1.0 - condition_video_mask_B_1_T_1_1) + timestep.reshape(timestep.shape[0], 1, 1, 1, 1) * condition_video_mask_B_1_T_1_1 + out = c_noise_B_1_T_1_1.squeeze(dim=[1, 3, 4]) + return out + + def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): + sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)) + sigma_noise_augmentation = 0 #TODO + if sigma_noise_augmentation != 0: + latent_image = latent_image + noise + latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image) + sigma = (sigma / (sigma + 1)) + return latent_image / (1.0 - sigma) + class Lumina2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 74f539598..4aa90d3b6 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -407,6 +407,58 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["text_emb_dim"] = 2048 return dit_config + if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2 + dit_config = {} + dit_config["image_model"] = "cosmos_predict2" + dit_config["max_img_h"] = 240 + dit_config["max_img_w"] = 240 + dit_config["max_frames"] = 128 + concat_padding_mask = True + dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask) + dit_config["out_channels"] = 16 + dit_config["patch_spatial"] = 2 + dit_config["patch_temporal"] = 1 + dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0] + dit_config["concat_padding_mask"] = concat_padding_mask + dit_config["crossattn_emb_channels"] = 1024 + dit_config["pos_emb_cls"] = "rope3d" + dit_config["pos_emb_learnable"] = True + dit_config["pos_emb_interpolation"] = "crop" + dit_config["min_fps"] = 1 + dit_config["max_fps"] = 30 + + dit_config["use_adaln_lora"] = True + dit_config["adaln_lora_dim"] = 256 + if dit_config["model_channels"] == 2048: + dit_config["num_blocks"] = 28 + dit_config["num_heads"] = 16 + elif dit_config["model_channels"] == 5120: + dit_config["num_blocks"] = 36 + dit_config["num_heads"] = 40 + + if dit_config["in_channels"] == 16: + dit_config["extra_per_block_abs_pos_emb"] = False + dit_config["rope_h_extrapolation_ratio"] = 4.0 + dit_config["rope_w_extrapolation_ratio"] = 4.0 + dit_config["rope_t_extrapolation_ratio"] = 1.0 + elif dit_config["in_channels"] == 17: # img to video + if dit_config["model_channels"] == 2048: + dit_config["extra_per_block_abs_pos_emb"] = False + dit_config["rope_h_extrapolation_ratio"] = 3.0 + dit_config["rope_w_extrapolation_ratio"] = 3.0 + 
dit_config["rope_t_extrapolation_ratio"] = 1.0 + elif dit_config["model_channels"] == 5120: + dit_config["rope_h_extrapolation_ratio"] = 2.0 + dit_config["rope_w_extrapolation_ratio"] = 2.0 + dit_config["rope_t_extrapolation_ratio"] = 0.8333333333333334 + + dit_config["extra_h_extrapolation_ratio"] = 1.0 + dit_config["extra_w_extrapolation_ratio"] = 1.0 + dit_config["extra_t_extrapolation_ratio"] = 1.0 + dit_config["rope_enable_fps_modulation"] = False + + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None diff --git a/comfy/model_management.py b/comfy/model_management.py index 2e06e884d..6a61c4d9c 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -319,6 +319,7 @@ except: pass +SUPPORT_FP8_OPS = args.supports_fp8_compute try: if is_amd(): try: @@ -329,9 +330,13 @@ try: logging.info("AMD arch: {}".format(arch)) logging.info("ROCm version: {}".format(rocm_version)) if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: - if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much - if any((a in arch) for a in ["gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches + if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much + if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx1201 and gfx950 ENABLE_PYTORCH_ATTENTION = True + if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4): + if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches + SUPPORT_FP8_OPS = True + except: pass @@ -352,7 +357,7 @@ except: pass try: - if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5: + if torch_version_numeric >= (2, 5): torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True) except: logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp") @@ -1075,7 +1080,7 @@ def pytorch_attention_flash_attention(): global ENABLE_PYTORCH_ATTENTION if ENABLE_PYTORCH_ATTENTION: #TODO: more reliable way of checking for flash attention? 
- if is_nvidia(): #pytorch flash attention only works on Nvidia + if is_nvidia(): return True if is_intel_xpu(): return True @@ -1091,7 +1096,7 @@ def force_upcast_attention_dtype(): upcast = args.force_upcast_attention macos_version = mac_version() - if macos_version is not None and ((14, 5) <= macos_version < (16,)): # black image bug on recent versions of macOS + if macos_version is not None and ((14, 5) <= macos_version): # black image bug on recent versions of macOS, I don't think it's ever getting fixed upcast = True if upcast: @@ -1290,7 +1295,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma return False def supports_fp8_compute(device=None): - if args.supports_fp8_compute: + if SUPPORT_FP8_OPS: return True if not is_nvidia(): @@ -1304,11 +1309,11 @@ def supports_fp8_compute(device=None): if props.minor < 9: return False - if torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 3): + if torch_version_numeric < (2, 3): return False if WINDOWS: - if (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 4): + if torch_version_numeric < (2, 4): return False return True diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index deb8af327..7cc62c4ed 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -17,23 +17,26 @@ """ from __future__ import annotations -from typing import Optional, Callable -import torch + +import collections import copy import inspect import logging -import uuid -import collections import math +import uuid +from typing import Callable, Optional + +import torch -import comfy.utils import comfy.float -import comfy.model_management -import comfy.lora import comfy.hooks +import comfy.lora +import comfy.model_management import comfy.patcher_extension -from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection +import comfy.utils from comfy.comfy_types import UnetWrapperFunction +from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP + def string_to_seed(data): crc = 0xFFFFFFFF diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 7e7291476..b240b7f29 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -77,6 +77,25 @@ class IMG_TO_IMG(X0): def calculate_input(self, sigma, noise): return noise +class COSMOS_RFLOW: + def calculate_input(self, sigma, noise): + sigma = (sigma / (sigma + 1)) + sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + return noise * (1.0 - sigma) + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = (sigma / (sigma + 1)) + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input * (1.0 - sigma) - model_output * sigma + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + noise = noise * sigma + noise += latent_image + return noise + + def inverse_noise_scaling(self, sigma, latent): + return latent class ModelSamplingDiscrete(torch.nn.Module): def __init__(self, model_config=None, zsnr=None): @@ -350,3 +369,15 @@ class ModelSamplingFlux(torch.nn.Module): if percent >= 1.0: return 0.0 return flux_time_shift(self.shift, 1.0, 1.0 - percent) + + +class ModelSamplingCosmosRFlow(ModelSamplingContinuousEDM): + def timestep(self, sigma): + return sigma / (sigma + 1) + + def sigma(self, timestep): + sigma_max = self.sigma_max + if timestep >= (sigma_max / (sigma_max + 1)): + return sigma_max + + return timestep / (1 - timestep) 
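A minimal sketch (not part of the patch) of the rectified-flow mapping that COSMOS_RFLOW and ModelSamplingCosmosRFlow above implement, assuming only the definitions in this hunk:

import torch

# timestep(sigma) = sigma / (sigma + 1) and sigma(t) = t / (1 - t) are inverses
sigma = torch.tensor([0.002, 1.0, 80.0])
t = sigma / (sigma + 1)
assert torch.allclose(t / (1 - t), sigma, rtol=1e-4)

# noise_scaling builds x_t = x_0 + sigma * noise; calculate_input then rescales it
# by 1 - t = 1 / (sigma + 1), so the model sees the rectified-flow interpolation
# (1 - t) * x_0 + t * noise
x0, noise = torch.randn(2, 1, 16, 4, 8, 8).unbind(0)
s, ts = sigma[1], t[1]
x_t = x0 + s * noise                     # noise_scaling
model_input = x_t * (1.0 - s / (s + 1))  # calculate_input
assert torch.allclose(model_input, (1 - ts) * x0 + ts * noise, rtol=1e-4, atol=1e-6)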
diff --git a/comfy/sd.py b/comfy/sd.py index e98a3aa87..cd13ab5f0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1081,7 +1081,28 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return (model_patcher, clip, vae, clipvision) -def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format +def load_diffusion_model_state_dict(sd, model_options={}): + """ + Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats. + + Args: + sd (dict): State dictionary containing model weights and configuration + model_options (dict, optional): Additional options for model loading. Supports: + - dtype: Override model data type + - custom_operations: Custom model operations + - fp8_optimizations: Enable FP8 optimizations + + Returns: + ModelPatcher: A wrapped model instance that handles device management and weight loading. + Returns None if the model configuration cannot be detected. + + The function: + 1. Detects and handles different model formats (regular, diffusers, mmdit) + 2. Configures model dtype based on parameters and device capabilities + 3. Handles weight conversion and device placement + 4. Manages model optimization settings + 5. Loads weights and returns a device-managed model instance + """ dtype = model_options.get("dtype", None) #Allow loading unets from checkpoint files diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index ac61babe9..1b69a4103 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -462,7 +462,7 @@ class SDTokenizer: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length) - self.min_length = min_length + self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length) self.end_token = None self.min_padding = min_padding diff --git a/comfy/supported_models.py b/comfy/supported_models.py index efe2e6b8f..19f25e337 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -908,6 +908,48 @@ class CosmosI2V(CosmosT2V): out = model_base.CosmosVideo(self, image_to_video=True, device=device) return out +class CosmosT2IPredict2(supported_models_base.BASE): + unet_config = { + "image_model": "cosmos_predict2", + "in_channels": 16, + } + + sampling_settings = { + "sigma_data": 1.0, + "sigma_max": 80.0, + "sigma_min": 0.002, + } + + unet_extra_config = {} + latent_format = latent_formats.Wan21 + + memory_usage_factor = 1.0 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.CosmosPredict2(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect)) + +class CosmosI2VPredict2(CosmosT2IPredict2): + unet_config = { + "image_model": "cosmos_predict2", + "in_channels": 17, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.CosmosPredict2(self, image_to_video=True, 
device=device) + return out + class Lumina2(supported_models_base.BASE): unet_config = { "image_model": "lumina2", @@ -1139,6 +1181,6 @@ class ACEStep(supported_models_base.BASE): def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep] models += [SVD_img2vid] diff --git a/comfy/weight_adapter/__init__.py b/comfy/weight_adapter/__init__.py index d2a1d0151..560b82be3 100644 --- a/comfy/weight_adapter/__init__.py +++ b/comfy/weight_adapter/__init__.py @@ -1,4 +1,4 @@ -from .base import WeightAdapterBase +from .base import WeightAdapterBase, WeightAdapterTrainBase from .lora import LoRAAdapter from .loha import LoHaAdapter from .lokr import LoKrAdapter @@ -15,3 +15,9 @@ adapters: list[type[WeightAdapterBase]] = [ OFTAdapter, BOFTAdapter, ] + +__all__ = [ + "WeightAdapterBase", + "WeightAdapterTrainBase", + "adapters" +] + [a.__name__ for a in adapters] diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py index 29873519d..b5c7db423 100644 --- a/comfy/weight_adapter/base.py +++ b/comfy/weight_adapter/base.py @@ -12,12 +12,20 @@ class WeightAdapterBase: weights: list[torch.Tensor] @classmethod - def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]: + def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]: raise NotImplementedError def to_train(self) -> "WeightAdapterTrainBase": raise NotImplementedError + @classmethod + def create_train(cls, weight, *args) -> "WeightAdapterTrainBase": + """ + weight: The original weight tensor to be modified. + *args: Additional arguments for configuration, such as rank, alpha etc. + """ + raise NotImplementedError + def calculate_weight( self, weight, @@ -33,10 +41,22 @@ class WeightAdapterBase: class WeightAdapterTrainBase(nn.Module): + # We follow the scheme of PR #7032 def __init__(self): super().__init__() - # [TODO] Collaborate with LoRA training PR #7032 + def __call__(self, w): + """ + w: The original weight tensor to be modified. 
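+        Returns the adapted weight tensor, with the same shape and dtype as w.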
+ """ + raise NotImplementedError + + def passive_memory_usage(self): + raise NotImplementedError("passive_memory_usage is not implemented") + + def move_to(self, device): + self.to(device) + return self.passive_memory_usage() def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function): @@ -102,3 +122,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten padded_tensor[new_slices] = tensor[orig_slices] return padded_tensor + + +def tucker_weight_from_conv(up, down, mid): + up = up.reshape(up.size(0), up.size(1)) + down = down.reshape(down.size(0), down.size(1)) + return torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down) + + +def tucker_weight(wa, wb, t): + temp = torch.einsum("i j ..., j r -> i r ...", t, wb) + return torch.einsum("i j ..., i r -> r j ...", temp, wa) diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index b2e623924..729dbd9e6 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -3,7 +3,56 @@ from typing import Optional import torch import comfy.model_management -from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape +from .base import ( + WeightAdapterBase, + WeightAdapterTrainBase, + weight_decompose, + pad_tensor_to_shape, + tucker_weight_from_conv, +) + + +class LoraDiff(WeightAdapterTrainBase): + def __init__(self, weights): + super().__init__() + mat1, mat2, alpha, mid, dora_scale, reshape = weights + out_dim, rank = mat1.shape[0], mat1.shape[1] + rank, in_dim = mat2.shape[0], mat2.shape[1] + if mid is not None: + convdim = mid.ndim - 2 + layer = ( + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d + )[convdim] + else: + layer = torch.nn.Linear + self.lora_up = layer(rank, out_dim, bias=False) + self.lora_down = layer(in_dim, rank, bias=False) + self.lora_up.weight.data.copy_(mat1) + self.lora_down.weight.data.copy_(mat2) + if mid is not None: + self.lora_mid = layer(mid, rank, bias=False) + self.lora_mid.weight.data.copy_(mid) + else: + self.lora_mid = None + self.rank = rank + self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False) + + def __call__(self, w): + org_dtype = w.dtype + if self.lora_mid is None: + diff = self.lora_up.weight @ self.lora_down.weight + else: + diff = tucker_weight_from_conv( + self.lora_up.weight, self.lora_down.weight, self.lora_mid.weight + ) + scale = self.alpha / self.rank + weight = w + scale * diff.reshape(w.shape) + return weight.to(org_dtype) + + def passive_memory_usage(self): + return sum(param.numel() * param.element_size() for param in self.parameters()) class LoRAAdapter(WeightAdapterBase): @@ -13,6 +62,21 @@ class LoRAAdapter(WeightAdapterBase): self.loaded_keys = loaded_keys self.weights = weights + @classmethod + def create_train(cls, weight, rank=1, alpha=1.0): + out_dim = weight.shape[0] + in_dim = weight.shape[1:].numel() + mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) + mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) + torch.nn.init.kaiming_uniform_(mat1, a=5**0.5) + torch.nn.init.constant_(mat2, 0.0) + return LoraDiff( + (mat1, mat2, alpha, None, None, None) + ) + + def to_train(self): + return LoraDiff(self.weights) + @classmethod def load( cls, diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 010564704..d93fbd778 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -346,20 +346,6 @@ class FluxKontextProImageNode(ComfyNodeABC): }, } - 
@classmethod - def VALIDATE_INPUTS(cls, aspect_ratio: str): - try: - validate_aspect_ratio( - aspect_ratio, - minimum_ratio=cls.MINIMUM_RATIO, - maximum_ratio=cls.MAXIMUM_RATIO, - minimum_ratio_str=cls.MINIMUM_RATIO_STR, - maximum_ratio_str=cls.MAXIMUM_RATIO_STR, - ) - except Exception as e: - return str(e) - return True - RETURN_TYPES = (IO.IMAGE,) DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value FUNCTION = "api_call" @@ -380,6 +366,13 @@ class FluxKontextProImageNode(ComfyNodeABC): unique_id: Union[str, None] = None, **kwargs, ): + aspect_ratio = validate_aspect_ratio( + aspect_ratio, + minimum_ratio=self.MINIMUM_RATIO, + maximum_ratio=self.MAXIMUM_RATIO, + minimum_ratio_str=self.MINIMUM_RATIO_STR, + maximum_ratio_str=self.MAXIMUM_RATIO_STR, + ) if input_image is None: validate_string(prompt, strip_whitespace=False) operation = SynchronousOperation( @@ -395,13 +388,7 @@ class FluxKontextProImageNode(ComfyNodeABC): guidance=round(guidance, 1), steps=steps, seed=seed, - aspect_ratio=validate_aspect_ratio( - aspect_ratio, - minimum_ratio=self.MINIMUM_RATIO, - maximum_ratio=self.MAXIMUM_RATIO, - minimum_ratio_str=self.MINIMUM_RATIO_STR, - maximum_ratio_str=self.MAXIMUM_RATIO_STR, - ), + aspect_ratio=aspect_ratio, input_image=( input_image if input_image is None diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index b1cbf511d..b8487355f 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -324,7 +324,7 @@ class IdeogramV1(ComfyNodeABC): RETURN_TYPES = (IO.IMAGE,) FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram/v1" + CATEGORY = "api node/image/Ideogram" DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True @@ -483,7 +483,7 @@ class IdeogramV2(ComfyNodeABC): RETURN_TYPES = (IO.IMAGE,) FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram/v2" + CATEGORY = "api node/image/Ideogram" DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True @@ -649,7 +649,7 @@ class IdeogramV3(ComfyNodeABC): RETURN_TYPES = (IO.IMAGE,) FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram/v3" + CATEGORY = "api node/image/Ideogram" DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True diff --git a/comfy_config/config_parser.py b/comfy_config/config_parser.py new file mode 100644 index 000000000..a9cbd94dd --- /dev/null +++ b/comfy_config/config_parser.py @@ -0,0 +1,97 @@ +import os +from pathlib import Path +from typing import Optional + +from pydantic_settings import PydanticBaseSettingsSource, TomlConfigSettingsSource + +from comfy_config.types import ( + ComfyConfig, + ProjectConfig, + PyProjectConfig, + PyProjectSettings +) + +""" +Extract configuration from a custom node directory's pyproject.toml file or a Python file. + +This function reads and parses the pyproject.toml file in the specified directory +to extract project and ComfyUI-specific configuration information. If no +pyproject.toml file is found, it creates a minimal configuration using the +folder name as the project name. If a Python file is provided, it uses the +file name (without extension) as the project name. + +Args: + path (str): Path to the directory containing the pyproject.toml file, or + path to a .py file. If pyproject.toml doesn't exist in a directory, + the folder name will be used as the default project name. If a .py + file is provided, the filename (without .py extension) will be used + as the project name. 
+ +Returns: + Optional[PyProjectConfig]: A PyProjectConfig object containing: + - project: Basic project information (name, version, dependencies, etc.) + - tool_comfy: ComfyUI-specific configuration (publisher_id, models, etc.) + Returns None if configuration extraction fails or if the provided file + is not a Python file. + +Notes: + - If pyproject.toml is missing in a directory, creates a default config with folder name + - If a .py file is provided, creates a default config with filename (without extension) + - Returns None for non-Python files + +Example: + >>> from comfy_config import config_parser + >>> # For directory + >>> custom_node_dir = os.path.dirname(os.path.realpath(__file__)) + >>> project_config = config_parser.extract_node_configuration(custom_node_dir) + >>> print(project_config.project.name) # "my_custom_node" or name from pyproject.toml + >>> + >>> # For single-file Python node file + >>> py_file_path = os.path.realpath(__file__) # "/path/to/my_node.py" + >>> project_config = config_parser.extract_node_configuration(py_file_path) + >>> print(project_config.project.name) # "my_node" +""" +def extract_node_configuration(path) -> Optional[PyProjectConfig]: + if os.path.isfile(path): + file_path = Path(path) + + if file_path.suffix.lower() != '.py': + return None + + project_name = file_path.stem + project = ProjectConfig(name=project_name) + comfy = ComfyConfig() + return PyProjectConfig(project=project, tool_comfy=comfy) + + folder_name = os.path.basename(path) + toml_path = Path(path) / "pyproject.toml" + + if not toml_path.exists(): + project = ProjectConfig(name=folder_name) + comfy = ComfyConfig() + return PyProjectConfig(project=project, tool_comfy=comfy) + + raw_settings = load_pyproject_settings(toml_path) + + project_data = raw_settings.project + + tool_data = raw_settings.tool + comfy_data = tool_data.get("comfy", {}) if tool_data else {} + + return PyProjectConfig(project=project_data, tool_comfy=comfy_data) + + +def load_pyproject_settings(toml_path: Path) -> PyProjectSettings: + class PyProjectLoader(PyProjectSettings): + @classmethod + def settings_customise_sources( + cls, + settings_cls, + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ): + return (TomlConfigSettingsSource(settings_cls, toml_path),) + + return PyProjectLoader() diff --git a/comfy_config/types.py b/comfy_config/types.py new file mode 100644 index 000000000..5222cc59b --- /dev/null +++ b/comfy_config/types.py @@ -0,0 +1,93 @@ +from pydantic import BaseModel, Field, field_validator +from pydantic_settings import BaseSettings, SettingsConfigDict +from typing import List, Optional + +# IMPORTANT: The type definitions specified in pyproject.toml for custom nodes +# must remain synchronized with the corresponding files in the https://github.com/Comfy-Org/comfy-cli/blob/main/comfy_cli/registry/types.py. +# Any changes to one must be reflected in the other to maintain consistency. 
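+# For illustration only (hypothetical values, not part of this patch): a pyproject.toml
+# that parses into the models below. "PublisherId" and "DisplayName" bind through the
+# field aliases on ComfyConfig, and a plain-string license is coerced into
+# License(text=...) by the validator on ProjectConfig.
+#
+#   [project]
+#   name = "my_custom_node"
+#   version = "1.0.0"
+#   license = "MIT"
+#
+#   [tool.comfy]
+#   PublisherId = "my-publisher"
+#   DisplayName = "My Custom Node"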
+ +class NodeVersion(BaseModel): + changelog: str + dependencies: List[str] + deprecated: bool + id: str + version: str + download_url: str + + +class Node(BaseModel): + id: str + name: str + description: str + author: Optional[str] = None + license: Optional[str] = None + icon: Optional[str] = None + repository: Optional[str] = None + tags: List[str] = Field(default_factory=list) + latest_version: Optional[NodeVersion] = None + + +class PublishNodeVersionResponse(BaseModel): + node_version: NodeVersion + signedUrl: str + + +class URLs(BaseModel): + homepage: str = Field(default="", alias="Homepage") + documentation: str = Field(default="", alias="Documentation") + repository: str = Field(default="", alias="Repository") + issues: str = Field(default="", alias="Issues") + + +class Model(BaseModel): + location: str + model_url: str + + +class ComfyConfig(BaseModel): + publisher_id: str = Field(default="", alias="PublisherId") + display_name: str = Field(default="", alias="DisplayName") + icon: str = Field(default="", alias="Icon") + models: List[Model] = Field(default_factory=list, alias="Models") + includes: List[str] = Field(default_factory=list) + web: Optional[str] = None + + +class License(BaseModel): + file: str = "" + text: str = "" + + +class ProjectConfig(BaseModel): + name: str = "" + description: str = "" + version: str = "1.0.0" + requires_python: str = Field(default=">= 3.9", alias="requires-python") + dependencies: List[str] = Field(default_factory=list) + license: License = Field(default_factory=License) + urls: URLs = Field(default_factory=URLs) + + @field_validator('license', mode='before') + @classmethod + def validate_license(cls, v): + if isinstance(v, str): + return License(text=v) + elif isinstance(v, dict): + return License(**v) + elif isinstance(v, License): + return v + else: + return License() + + +class PyProjectConfig(BaseModel): + project: ProjectConfig = Field(default_factory=ProjectConfig) + tool_comfy: ComfyConfig = Field(default_factory=ComfyConfig) + + +class PyProjectSettings(BaseSettings): + project: dict = Field(default_factory=dict) + + tool: dict = Field(default_factory=dict) + + model_config = SettingsConfigDict(extra='allow') diff --git a/comfy_extras/nodes_cosmos.py b/comfy_extras/nodes_cosmos.py index bd35ddb06..4f4960551 100644 --- a/comfy_extras/nodes_cosmos.py +++ b/comfy_extras/nodes_cosmos.py @@ -2,6 +2,7 @@ import nodes import torch import comfy.model_management import comfy.utils +import comfy.latent_formats class EmptyCosmosLatentVideo: @@ -75,8 +76,53 @@ class CosmosImageToVideoLatent: out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) return (out_latent,) +class CosmosPredict2ImageToVideoLatent: + @classmethod + def INPUT_TYPES(s): + return {"required": {"vae": ("VAE", ), + "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 93, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"start_image": ("IMAGE", ), + "end_image": ("IMAGE", ), + }} + + + RETURN_TYPES = ("LATENT",) + FUNCTION = "encode" + + CATEGORY = "conditioning/inpaint" + + def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None): + latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if 
start_image is None and end_image is None:
+            out_latent = {}
+            out_latent["samples"] = latent
+            return (out_latent,)
+
+        mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
+
+        if start_image is not None:
+            latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1)
+            latent[:, :, :latent_temp.shape[-3]] = latent_temp
+            mask[:, :, :latent_temp.shape[-3]] *= 0.0
+
+        if end_image is not None:
+            latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0)
+            latent[:, :, -latent_temp.shape[-3]:] = latent_temp
+            mask[:, :, -latent_temp.shape[-3]:] *= 0.0
+
+        out_latent = {}
+        latent_format = comfy.latent_formats.Wan21()
+        latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
+        out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
+        out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
+        return (out_latent,)

 NODE_CLASS_MAPPINGS = {
     "EmptyCosmosLatentVideo": EmptyCosmosLatentVideo,
     "CosmosImageToVideoLatent": CosmosImageToVideoLatent,
+    "CosmosPredict2ImageToVideoLatent": CosmosPredict2ImageToVideoLatent,
 }
diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py
index 6ebf1dbd8..b1e0d4666 100644
--- a/comfy_extras/nodes_images.py
+++ b/comfy_extras/nodes_images.py
@@ -16,7 +16,8 @@ from inspect import cleandoc
 import torch
 import comfy.utils
-from comfy.comfy_types import FileLocator
+from comfy.comfy_types import FileLocator, IO
+from server import PromptServer
 MAX_RESOLUTION = nodes.MAX_RESOLUTION
@@ -491,6 +492,37 @@ class SaveSVGNode:
             counter += 1
         return { "ui": { "images": results } }

+class GetImageSize:
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": (IO.IMAGE,),
+            },
+            "hidden": {
+                "unique_id": "UNIQUE_ID",
+            }
+        }
+
+    RETURN_TYPES = (IO.INT, IO.INT, IO.INT)
+    RETURN_NAMES = ("width", "height", "batch_size")
+    FUNCTION = "get_size"
+
+    CATEGORY = "image"
+    DESCRIPTION = """Returns the width, height, and batch size of the input image."""
+
+    def get_size(self, image, unique_id=None) -> tuple[int, int, int]:
+        height = image.shape[1]
+        width = image.shape[2]
+        batch_size = image.shape[0]
+
+        # Send progress text to display the size on the node
+        if unique_id:
+            PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id)
+
+        return width, height, batch_size
+
 NODE_CLASS_MAPPINGS = {
     "ImageCrop": ImageCrop,
     "RepeatImageBatch": RepeatImageBatch,
@@ -500,4 +532,5 @@ NODE_CLASS_MAPPINGS = {
     "SaveAnimatedPNG": SaveAnimatedPNG,
     "SaveSVGNode": SaveSVGNode,
     "ImageStitch": ImageStitch,
+    "GetImageSize": GetImageSize,
 }
diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py
index 71a652ffa..ae5d2c563 100644
--- a/comfy_extras/nodes_model_advanced.py
+++ b/comfy_extras/nodes_model_advanced.py
@@ -189,7 +189,7 @@ class ModelSamplingContinuousEDM:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "model": ("MODEL",),
-                              "sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps"],),
+                              "sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps", "cosmos_rflow"],),
                               "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
                               "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
                               }}
@@ -202,6 +202,7 @@ class ModelSamplingContinuousEDM:
     def 
patch(self, model, sampling, sigma_max, sigma_min):
         m = model.clone()

+        sampling_base = comfy.model_sampling.ModelSamplingContinuousEDM
         latent_format = None
         sigma_data = 1.0
         if sampling == "eps":
@@ -215,8 +216,11 @@ class ModelSamplingContinuousEDM:
             sampling_type = comfy.model_sampling.EDM
             sigma_data = 0.5
             latent_format = comfy.latent_formats.SDXL_Playground_2_5()
+        elif sampling == "cosmos_rflow":
+            sampling_type = comfy.model_sampling.COSMOS_RFLOW
+            sampling_base = comfy.model_sampling.ModelSamplingCosmosRFlow

-        class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
+        class ModelSamplingAdvanced(sampling_base, sampling_type):
             pass

         model_sampling = ModelSamplingAdvanced(model.model.model_config)
diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py
new file mode 100644
index 000000000..fbff01010
--- /dev/null
+++ b/comfy_extras/nodes_train.py
@@ -0,0 +1,709 @@
+import datetime
+import json
+import logging
+import os
+
+import numpy as np
+import safetensors
+import torch
+from PIL import Image, ImageDraw, ImageFont
+from PIL.PngImagePlugin import PngInfo
+import torch.utils.checkpoint
+import tqdm
+
+import comfy.samplers
+import comfy.sd
+import comfy.utils
+import comfy.model_management
+import comfy_extras.nodes_custom_sampler
+import folder_paths
+import node_helpers
+from comfy.cli_args import args
+from comfy.comfy_types.node_typing import IO
+from comfy.weight_adapter import adapters
+
+
+class TrainSampler(comfy.samplers.Sampler):
+
+    def __init__(self, loss_fn, optimizer, loss_callback=None):
+        self.loss_fn = loss_fn
+        self.optimizer = optimizer
+        self.loss_callback = loss_callback
+
+    def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
+        self.optimizer.zero_grad()
+        noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False)
+        latent = model_wrap.inner_model.model_sampling.noise_scaling(
+            torch.zeros_like(sigmas),
+            torch.zeros_like(noise, requires_grad=True),
+            latent_image,
+            False
+        )
+
+        # Ensure model is in training mode and computing gradients
+        # x0 pred
+        denoised = model_wrap(noise, sigmas, **extra_args)
+        try:
+            loss = self.loss_fn(denoised, latent.clone())
+        except RuntimeError as e:
+            if "does not require grad and does not have a grad_fn" in str(e):
+                logging.info("WARNING: This is likely because the model was loaded in inference mode.")
+            raise  # loss is undefined here, so training cannot continue
+        loss.backward()
+        if self.loss_callback:
+            self.loss_callback(loss.item())
+
+        self.optimizer.step()
+        # torch.cuda.memory._dump_snapshot("trainn.pickle")
+        # torch.cuda.memory._record_memory_history(enabled=None)
+        return torch.zeros_like(latent_image)
+
+
+class BiasDiff(torch.nn.Module):
+    def __init__(self, bias):
+        super().__init__()
+        self.bias = bias
+
+    def __call__(self, b):
+        org_dtype = b.dtype
+        return (b.to(self.bias) + self.bias).to(org_dtype)
+
+    def passive_memory_usage(self):
+        return self.bias.nelement() * self.bias.element_size()
+
+    def move_to(self, device):
+        self.to(device=device)
+        return self.passive_memory_usage()
+
+
+def load_and_process_images(image_files, input_dir, resize_method="None"):
+    """Utility function to load and process a list of images.
+
+    Args:
+        image_files: List of image filenames
+        input_dir: Base directory containing the images
+        resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")
+
+    Returns:
+        torch.Tensor: Batch of processed images
+    """
+    if not image_files:
+        raise ValueError("No valid images found in input")
+
+    output_images = []
+    w, h = None, None
+
+    for file in image_files:
+        image_path = os.path.join(input_dir, file)
+        img = node_helpers.pillow(Image.open, image_path)
+
+        if img.mode == "I":
+            img = img.point(lambda i: i * (1 / 255))
+        img = img.convert("RGB")
+
+        if w is None and h is None:
+            w, h = img.size[0], img.size[1]
+
+        # Resize image to match the first image in the dataset
+        if img.size[0] != w or img.size[1] != h:
+            if resize_method == "Stretch":
+                img = img.resize((w, h), Image.Resampling.LANCZOS)
+            elif resize_method == "Crop":
+                img = img.crop((0, 0, w, h))
+            elif resize_method == "Pad":
+                img = img.resize((w, h), Image.Resampling.LANCZOS)
+            elif resize_method == "None":
+                raise ValueError(
+                    "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
+                )
+
+        img_array = np.array(img).astype(np.float32) / 255.0
+        img_tensor = torch.from_numpy(img_array)[None,]
+        output_images.append(img_tensor)
+
+    return torch.cat(output_images, dim=0)
+
+
+class LoadImageSetNode:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "images": (
+                    [
+                        f
+                        for f in os.listdir(folder_paths.get_input_directory())
+                        if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"))
+                    ],
+                    {"image_upload": True, "allow_batch": True},
+                )
+            },
+            "optional": {
+                "resize_method": (
+                    ["None", "Stretch", "Crop", "Pad"],
+                    {"default": "None"},
+                ),
+            },
+        }
+
+    INPUT_IS_LIST = True
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "load_images"
+    CATEGORY = "loaders"
+    EXPERIMENTAL = True
+    DESCRIPTION = "Loads a batch of images from a directory for training."
+
+    @classmethod
+    def VALIDATE_INPUTS(s, images, resize_method):
+        filenames = images[0] if isinstance(images[0], list) else images
+
+        for image in filenames:
+            if not folder_paths.exists_annotated_filepath(image):
+                return "Invalid image file: {}".format(image)
+        return True
+
+    def load_images(self, images, resize_method):
+        # Parameter names must match the inputs declared in INPUT_TYPES; with
+        # INPUT_IS_LIST every input arrives wrapped in a list.
+        if isinstance(resize_method, list):
+            resize_method = resize_method[0]
+        input_dir = folder_paths.get_input_directory()
+        valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"]
+        image_files = [
+            f
+            for f in images
+            if any(f.lower().endswith(ext) for ext in valid_extensions)
+        ]
+        output_tensor = load_and_process_images(image_files, input_dir, resize_method)
+        return (output_tensor,)
+
+
+class LoadImageSetFromFolderNode:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."})
+            },
+            "optional": {
+                "resize_method": (
+                    ["None", "Stretch", "Crop", "Pad"],
+                    {"default": "None"},
+                ),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "load_images"
+    CATEGORY = "loaders"
+    EXPERIMENTAL = True
+    DESCRIPTION = "Loads a batch of images from a directory for training."
+
+
+class LoadImageSetFromFolderNode:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."})
+            },
+            "optional": {
+                "resize_method": (
+                    ["None", "Stretch", "Crop", "Pad"],
+                    {"default": "None"},
+                ),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "load_images"
+    CATEGORY = "loaders"
+    EXPERIMENTAL = True
+    DESCRIPTION = "Loads a batch of images from a directory for training."
+
+    def load_images(self, folder, resize_method):
+        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
+        valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
+        image_files = [
+            f
+            for f in os.listdir(sub_input_dir)
+            if any(f.lower().endswith(ext) for ext in valid_extensions)
+        ]
+        output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method)
+        return (output_tensor,)
+
+
+def draw_loss_graph(loss_map, steps):
+    width, height = 500, 300
+    img = Image.new("RGB", (width, height), "white")
+    draw = ImageDraw.Draw(img)
+
+    min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
+    scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_map.values()]
+
+    prev_point = (0, height - int(scaled_loss[0] * height))
+    for i, l in enumerate(scaled_loss[1:], start=1):
+        x = int(i / (steps - 1) * width)
+        y = height - int(l * height)
+        draw.line([prev_point, (x, y)], fill="blue", width=2)
+        prev_point = (x, y)
+
+    return img
+
+
+def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None):
+    # The root call (result is None) always recurses, so the model itself is never
+    # collected; only its highest-level children that define forward() are returned.
+    if result is None:
+        result = []
+    elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
+        result.append(model)
+        logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
+        return result
+    name = name or "root"
+    for next_name, child in model.named_children():
+        find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}")
+    return result
+
+
+def patch(m):
+    if not hasattr(m, "forward"):
+        return
+    org_forward = m.forward
+    def fwd(args, kwargs):
+        # torch.utils.checkpoint.checkpoint calls fwd(args, kwargs) positionally
+        return org_forward(*args, **kwargs)
+    def checkpointing_fwd(*args, **kwargs):
+        return torch.utils.checkpoint.checkpoint(
+            fwd, args, kwargs, use_reentrant=False
+        )
+    m.org_forward = org_forward
+    m.forward = checkpointing_fwd
+
+
+def unpatch(m):
+    if hasattr(m, "org_forward"):
+        m.forward = m.org_forward
+        del m.org_forward
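patch()/unpatch() wrap a module's forward in torch.utils.checkpoint so activations are recomputed during the backward pass instead of stored, trading compute for memory during training. A self-contained sketch of the same wrapping on a toy module:

```python
import torch
import torch.utils.checkpoint

def patch(m):
    org_forward = m.forward
    def fwd(args, kwargs):
        return org_forward(*args, **kwargs)
    def checkpointing_fwd(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
    m.org_forward = org_forward
    m.forward = checkpointing_fwd

block = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 8))
x = torch.randn(2, 8, requires_grad=True)

ref = block(x).sum()
patch(block)
chk = block(x).sum()

# Same result; intermediate activations are recomputed in backward instead of stored.
print(torch.allclose(ref, chk))  # True
chk.backward()
```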
+
+
+class TrainLoraNode:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}),
+                "latents": (
+                    "LATENT",
+                    {
+                        "tooltip": "The latents to use for training; they serve as the dataset/input for the model."
+                    },
+                ),
+                "positive": (
+                    IO.CONDITIONING,
+                    {"tooltip": "The positive conditioning to use for training."},
+                ),
+                "batch_size": (
+                    IO.INT,
+                    {
+                        "default": 1,
+                        "min": 1,
+                        "max": 10000,
+                        "step": 1,
+                        "tooltip": "The batch size to use for training.",
+                    },
+                ),
+                "steps": (
+                    IO.INT,
+                    {
+                        "default": 16,
+                        "min": 1,
+                        "max": 100000,
+                        "tooltip": "The number of steps to train the LoRA for.",
+                    },
+                ),
+                "learning_rate": (
+                    IO.FLOAT,
+                    {
+                        "default": 0.0005,
+                        "min": 0.0000001,
+                        "max": 1.0,
+                        "step": 0.000001,
+                        "tooltip": "The learning rate to use for training.",
+                    },
+                ),
+                "rank": (
+                    IO.INT,
+                    {
+                        "default": 8,
+                        "min": 1,
+                        "max": 128,
+                        "tooltip": "The rank of the LoRA layers.",
+                    },
+                ),
+                "optimizer": (
+                    ["AdamW", "Adam", "SGD", "RMSprop"],
+                    {
+                        "default": "AdamW",
+                        "tooltip": "The optimizer to use for training.",
+                    },
+                ),
+                "loss_function": (
+                    ["MSE", "L1", "Huber", "SmoothL1"],
+                    {
+                        "default": "MSE",
+                        "tooltip": "The loss function to use for training.",
+                    },
+                ),
+                "seed": (
+                    IO.INT,
+                    {
+                        "default": 0,
+                        "min": 0,
+                        "max": 0xFFFFFFFFFFFFFFFF,
+                        "tooltip": "The seed to use for training (used in the generator for LoRA weight initialization and noise sampling)",
+                    },
+                ),
+                "training_dtype": (
+                    ["bf16", "fp32"],
+                    {"default": "bf16", "tooltip": "The dtype to use for training."},
+                ),
+                "lora_dtype": (
+                    ["bf16", "fp32"],
+                    {"default": "bf16", "tooltip": "The dtype to use for the LoRA."},
+                ),
+                "existing_lora": (
+                    folder_paths.get_filename_list("loras") + ["[None]"],
+                    {
+                        "default": "[None]",
+                        "tooltip": "The existing LoRA to append to. Set to [None] for a new LoRA.",
+                    },
+                ),
+            },
+        }
+
+    RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT)
+    RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps")
+    FUNCTION = "train"
+    CATEGORY = "training"
+    EXPERIMENTAL = True
+
+    def train(
+        self,
+        model,
+        latents,
+        positive,
+        batch_size,
+        steps,
+        learning_rate,
+        rank,
+        optimizer,
+        loss_function,
+        seed,
+        training_dtype,
+        lora_dtype,
+        existing_lora,
+    ):
+        mp = model.clone()
+        dtype = node_helpers.string_to_torch_dtype(training_dtype)
+        lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
+        mp.set_model_compute_dtype(dtype)
+
+        latents = latents["samples"].to(dtype)
+        num_images = latents.shape[0]
+
+        with torch.inference_mode(False):
+            lora_sd = {}
+            generator = torch.Generator()
+            generator.manual_seed(seed)
+
+            # Load existing LoRA weights if provided
+            existing_weights = {}
+            existing_steps = 0
+            if existing_lora != "[None]":
+                lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
+                # Extract steps from a filename like "trained_lora_10_steps_20250225_203716";
+                # fall back to 0 for files that do not follow this naming pattern
+                try:
+                    existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
+                except ValueError:
+                    existing_steps = 0
+                if lora_path:
+                    existing_weights = comfy.utils.load_torch_file(lora_path)
+
+            all_weight_adapters = []
+            for n, m in mp.model.named_modules():
+                if hasattr(m, "weight_function"):
+                    if m.weight is not None:
+                        key = "{}.weight".format(n)
+                        shape = m.weight.shape
+                        if len(shape) >= 2:
+                            alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
+                            dora_scale = existing_weights.get(
+                                f"{key}.dora_scale", None
+                            )
+                            for adapter_cls in adapters:
+                                existing_adapter = adapter_cls.load(
+                                    n, existing_weights, alpha, dora_scale
+                                )
+                                if existing_adapter is not None:
+                                    break
+                            else:
+                                # If no existing adapter found, use LoRA
+                                # We will add algo option in the future
+                                existing_adapter = None
+                                adapter_cls = adapters[0]
+
+                            if existing_adapter is not None:
+                                train_adapter = existing_adapter.to_train().to(lora_dtype)
+                            else:
+                                # Use LoRA with alpha=1.0 by default
+                                train_adapter = adapter_cls.create_train(
+                                    m.weight, rank=rank, alpha=1.0
+                                ).to(lora_dtype)
+                            for name, parameter in train_adapter.named_parameters():
+                                lora_sd[f"{n}.{name}"] = parameter
+
+                            mp.add_weight_wrapper(key, train_adapter)
+                            all_weight_adapters.append(train_adapter)
+                        else:
+                            diff = torch.nn.Parameter(
+                                torch.zeros(
+                                    m.weight.shape, dtype=lora_dtype, requires_grad=True
+                                )
+                            )
+                            diff_module = BiasDiff(diff)
+                            mp.add_weight_wrapper(key, BiasDiff(diff))
+                            all_weight_adapters.append(diff_module)
+                            lora_sd["{}.diff".format(n)] = diff
+                    if hasattr(m, "bias") and m.bias is not None:
+                        key = "{}.bias".format(n)
+                        bias = torch.nn.Parameter(
+                            torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True)
+                        )
+                        bias_module = BiasDiff(bias)
+                        lora_sd["{}.diff_b".format(n)] = bias
+                        mp.add_weight_wrapper(key, BiasDiff(bias))
+                        all_weight_adapters.append(bias_module)
+
+            if optimizer == "Adam":
+                optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
+            elif optimizer == "AdamW":
+                optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
+            elif optimizer == "SGD":
+                optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
+            elif optimizer == "RMSprop":
+                optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
+
+            # Setup loss function based on selection
+            if loss_function == "MSE":
+                criterion = torch.nn.MSELoss()
+            elif loss_function == "L1":
+                criterion = torch.nn.L1Loss()
+            elif loss_function == "Huber":
+                criterion = torch.nn.HuberLoss()
+            elif loss_function == "SmoothL1":
+                criterion = torch.nn.SmoothL1Loss()
+
+            # Setup models: patch for gradient checkpointing, then force a full load
+            for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
+                patch(m)
+            comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
+
+            # Setup sampler and guider
+            loss_map = {"loss": []}
+            def loss_callback(loss):
+                loss_map["loss"].append(loss)
+                pbar.set_postfix({"loss": f"{loss:.4f}"})
+            train_sampler = TrainSampler(
+                criterion, optimizer, loss_callback=loss_callback
+            )
+            guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
+            guider.set_conds(positive)  # Set conditioning from input
+            ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced()
+
+            # Training loop
+            torch.cuda.empty_cache()
+            try:
+                for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
+                    # Generate random sigma
+                    sigma = mp.model.model_sampling.percent_to_sigma(
+                        torch.rand((1,)).item()
+                    )
+                    sigma = torch.tensor([sigma])
+
+                    noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)
+
+                    indices = torch.randperm(num_images)[:batch_size]
+                    ss.sample(
+                        noise, guider, train_sampler, sigma, {"samples": latents[indices].clone()}
+                    )
+            finally:
+                for m in mp.model.modules():
+                    unpatch(m)
+                del ss, train_sampler, optimizer
+                torch.cuda.empty_cache()
+
+            for adapter in all_weight_adapters:
+                adapter.requires_grad_(False)
+
+            for param in lora_sd:
+                lora_sd[param] = lora_sd[param].to(lora_dtype)
+
+        return (mp, lora_sd, loss_map, steps + existing_steps)
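Each iteration of the loop above is one supervised step: draw a random sigma, noise a random mini-batch of latents with a per-step seed, predict x0, and take an optimizer step on the chosen loss. A toy version of that loop shape, with a linear model standing in for the diffusion model (all names here are illustrative, not ComfyUI API):

```python
import torch

model = torch.nn.Linear(16, 16)                   # stand-in for the diffusion model
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
criterion = torch.nn.MSELoss()
latents = torch.randn(32, 16)                     # stand-in dataset

seed = 0
for step in range(100):
    noise_gen = torch.Generator().manual_seed(step * 1000 + seed)  # per-step seed scheme
    idx = torch.randperm(latents.shape[0])[:4]                     # random mini-batch
    clean = latents[idx]
    noised = clean + torch.randn(clean.shape, generator=noise_gen)

    optimizer.zero_grad()
    loss = criterion(model(noised), clean)        # x0-prediction style target
    loss.backward()
    optimizer.step()
```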
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), + } + } + + RETURN_TYPES = ("MODEL",) + OUTPUT_TOOLTIPS = ("The modified diffusion model.",) + FUNCTION = "load_lora_model" + + CATEGORY = "loaders" + DESCRIPTION = "Load Trained LoRA weights from Train LoRA node." + EXPERIMENTAL = True + + def load_lora_model(self, model, lora, strength_model): + if strength_model == 0: + return (model, ) + + model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0) + return (model_lora, ) + + +class SaveLoRA: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "lora": ( + IO.LORA_MODEL, + { + "tooltip": "The LoRA model to save. Do not use the model with LoRA layers." + }, + ), + "prefix": ( + "STRING", + { + "default": "loras/ComfyUI_trained_lora", + "tooltip": "The prefix to use for the saved LoRA file.", + }, + ), + }, + "optional": { + "steps": ( + IO.INT, + { + "forceInput": True, + "tooltip": "Optional: The number of steps to LoRA has been trained for, used to name the saved file.", + }, + ), + }, + } + + RETURN_TYPES = () + FUNCTION = "save" + CATEGORY = "loaders" + EXPERIMENTAL = True + OUTPUT_NODE = True + + def save(self, lora, prefix, steps=None): + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(prefix, self.output_dir) + if steps is None: + output_checkpoint = f"{filename}_{counter:05}_.safetensors" + else: + output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors" + output_checkpoint = os.path.join(full_output_folder, output_checkpoint) + safetensors.torch.save_file(lora, output_checkpoint) + return {} + + +class LossGraphNode: + def __init__(self): + self.output_dir = folder_paths.get_temp_directory() + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "loss": (IO.LOSS_MAP, {"default": {}}), + "filename_prefix": (IO.STRING, {"default": "loss_graph"}), + }, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + + RETURN_TYPES = () + FUNCTION = "plot_loss" + OUTPUT_NODE = True + CATEGORY = "training" + EXPERIMENTAL = True + DESCRIPTION = "Plots the loss graph and saves it to the output directory." 
+
+
+class LossGraphNode:
+    def __init__(self):
+        self.output_dir = folder_paths.get_temp_directory()
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "loss": (IO.LOSS_MAP, {"default": {}}),
+                "filename_prefix": (IO.STRING, {"default": "loss_graph"}),
+            },
+            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
+        }
+
+    RETURN_TYPES = ()
+    FUNCTION = "plot_loss"
+    OUTPUT_NODE = True
+    CATEGORY = "training"
+    EXPERIMENTAL = True
+    DESCRIPTION = "Plots the loss graph and saves it to the temp directory."
+
+    def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None):
+        loss_values = loss["loss"]
+        width, height = 800, 480
+        margin = 40
+
+        img = Image.new(
+            "RGB", (width + margin, height + margin), "white"
+        )  # Extend canvas
+        draw = ImageDraw.Draw(img)
+
+        min_loss, max_loss = min(loss_values), max(loss_values)
+        loss_range = max_loss - min_loss or 1.0  # avoid division by zero for a constant loss
+        scaled_loss = [(l - min_loss) / loss_range for l in loss_values]
+
+        steps = len(loss_values)
+
+        prev_point = (margin, height - int(scaled_loss[0] * height))
+        for i, l in enumerate(scaled_loss[1:], start=1):
+            x = margin + int(i / steps * width)  # Scale X properly
+            y = height - int(l * height)
+            draw.line([prev_point, (x, y)], fill="blue", width=2)
+            prev_point = (x, y)
+
+        draw.line([(margin, 0), (margin, height)], fill="black", width=2)  # Y-axis
+        draw.line(
+            [(margin, height), (width + margin, height)], fill="black", width=2
+        )  # X-axis
+
+        try:
+            font = ImageFont.truetype("arial.ttf", 12)
+        except IOError:
+            font = ImageFont.load_default()
+
+        # Add axis labels
+        draw.text((5, height // 2), "Loss", font=font, fill="black")
+        draw.text((width // 2, height + 10), "Steps", font=font, fill="black")
+
+        # Add min/max loss values
+        draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
+        draw.text(
+            (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
+        )
+
+        metadata = None
+        if not args.disable_metadata:
+            metadata = PngInfo()
+            if prompt is not None:
+                metadata.add_text("prompt", json.dumps(prompt))
+            if extra_pnginfo is not None:
+                for x in extra_pnginfo:
+                    metadata.add_text(x, json.dumps(extra_pnginfo[x]))
+
+        date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+        img.save(
+            os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"),
+            pnginfo=metadata,
+        )
+        return {
+            "ui": {
+                "images": [
+                    {
+                        "filename": f"{filename_prefix}_{date}.png",
+                        "subfolder": "",
+                        "type": "temp",
+                    }
+                ]
+            }
+        }
+
+
+NODE_CLASS_MAPPINGS = {
+    "TrainLoraNode": TrainLoraNode,
+    "SaveLoRANode": SaveLoRA,
+    "LoraModelLoader": LoraModelLoader,
+    "LoadImageSetFromFolderNode": LoadImageSetFromFolderNode,
+    "LossGraphNode": LossGraphNode,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "TrainLoraNode": "Train LoRA",
+    "SaveLoRANode": "Save LoRA Weights",
+    "LoraModelLoader": "Load LoRA Model",
+    "LoadImageSetFromFolderNode": "Load Image Dataset from Folder",
+    "LossGraphNode": "Plot Loss Graph",
+}
diff --git a/comfy_extras/nodes_webcam.py b/comfy_extras/nodes_webcam.py
index 062b15cf8..5bf80b4c6 100644
--- a/comfy_extras/nodes_webcam.py
+++ b/comfy_extras/nodes_webcam.py
@@ -23,6 +23,10 @@ class WebcamCapture(nodes.LoadImage):
     def load_capture(self, image, **kwargs):
         return super().load_image(folder_paths.get_annotated_filepath(image))
 
+    @classmethod
+    def IS_CHANGED(cls, image, width, height, capture_on_queue):
+        return super().IS_CHANGED(image)
+
 
 NODE_CLASS_MAPPINGS = {
     "WebcamCapture": WebcamCapture,
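The WebcamCapture change fixes cache invalidation: the node takes extra width/height/capture_on_queue inputs, so the executor calls IS_CHANGED with those arguments and the inherited signature would not match. Delegating to LoadImage's IS_CHANGED (which, in current ComfyUI, hashes the file contents) keeps re-execution tied to the captured image. The general shape of such an override, as a hedged sketch:

```python
import hashlib

class MyFileNode:
    @classmethod
    def IS_CHANGED(cls, image, **extra_inputs):
        # Re-run the node whenever the underlying file content changes;
        # the returned value is compared against the previous execution's value.
        with open(image, "rb") as f:
            return hashlib.sha256(f.read()).hexdigest()
```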
diff --git a/comfyui_version.py b/comfyui_version.py
index f742410b1..fedd3466f 100644
--- a/comfyui_version.py
+++ b/comfyui_version.py
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.39"
+__version__ = "0.3.41"
diff --git a/execution.py b/execution.py
index 15ff7567c..d0012afda 100644
--- a/execution.py
+++ b/execution.py
@@ -1,23 +1,35 @@
-import sys
 import copy
-import logging
-import threading
 import heapq
+import inspect
+import logging
+import sys
+import threading
 import time
 import traceback
 from enum import Enum
-import inspect
 from typing import List, Literal, NamedTuple, Optional
 
 import torch
 
-import nodes
 import comfy.model_management
-from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
-from comfy_execution.graph_utils import is_link, GraphBuilder
-from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
+import nodes
+from comfy_execution.caching import (
+    CacheKeySetID,
+    CacheKeySetInputSignature,
+    DependencyAwareCache,
+    HierarchicalCache,
+    LRUCache,
+)
+from comfy_execution.graph import (
+    DynamicPrompt,
+    ExecutionBlocker,
+    ExecutionList,
+    get_input_info,
+)
+from comfy_execution.graph_utils import GraphBuilder, is_link
 from comfy_execution.validation import validate_node_input
 
+
 class ExecutionResult(Enum):
     SUCCESS = 0
     FAILURE = 1
diff --git a/folder_paths.py b/folder_paths.py
index f0b3fd103..9ec952940 100644
--- a/folder_paths.py
+++ b/folder_paths.py
@@ -276,6 +276,9 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
 
 
 def get_full_path(folder_name: str, filename: str) -> str | None:
+    """
+    Get the full path of a file in a folder. Returns None unless the path exists and is a file.
+    """
     global folder_names_and_paths
     folder_name = map_legacy(folder_name)
     if folder_name not in folder_names_and_paths:
@@ -293,6 +296,9 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
 
 
 def get_full_path_or_raise(folder_name: str, filename: str) -> str:
+    """
+    Get the full path of a file in a folder; raises FileNotFoundError if it is missing.
+    """
     full_path = get_full_path(folder_name, filename)
     if full_path is None:
         raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
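The two docstrings codify the contract: get_full_path returns None for anything that is not an existing file, while get_full_path_or_raise turns that into an exception. Typical call sites look like this, inside a ComfyUI environment (the filenames are hypothetical):

```python
import folder_paths

# Returns None if the checkpoint is missing, so the caller must check
path = folder_paths.get_full_path("checkpoints", "sd_xl_base_1.0.safetensors")
if path is None:
    print("checkpoint not installed")

# Raises FileNotFoundError instead, for call sites that cannot proceed without the file
path = folder_paths.get_full_path_or_raise("loras", "my_lora.safetensors")
```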
@@ -394,3 +400,26 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0)
     os.makedirs(full_output_folder, exist_ok=True)
     counter = 1
     return full_output_folder, filename, counter, subfolder, filename_prefix
+
+def get_input_subfolders() -> list[str]:
+    """Returns a list of all subfolder paths in the input directory, recursively.
+
+    Returns:
+        List of folder paths relative to the input directory, excluding the root directory
+    """
+    input_dir = get_input_directory()
+    folders = []
+
+    try:
+        if not os.path.exists(input_dir):
+            return []
+
+        for root, dirs, _ in os.walk(input_dir):
+            rel_path = os.path.relpath(root, input_dir)
+            if rel_path != ".":  # Only include non-root directories
+                # Normalize path separators to forward slashes
+                folders.append(rel_path.replace(os.sep, '/'))
+
+        return sorted(folders)
+    except FileNotFoundError:
+        return []
diff --git a/main.py b/main.py
index fb1f8d20b..c8c4194d4 100644
--- a/main.py
+++ b/main.py
@@ -17,7 +17,6 @@ if __name__ == "__main__":
     os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
     os.environ['DO_NOT_TRACK'] = '1'
-
 setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
 
 
 def apply_custom_paths():
@@ -238,6 +237,15 @@ def cleanup_temp():
         shutil.rmtree(temp_dir, ignore_errors=True)
 
 
+def setup_database():
+    try:
+        from app.database.db import init_db, dependencies_available
+        if dependencies_available():
+            init_db()
+    except Exception as e:
+        logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report it, as the database will be required in a future version: {e}")
+
+
 def start_comfyui(asyncio_loop=None):
     """
     Starts the ComfyUI server using the provided asyncio event loop or creates a new one.
@@ -266,6 +274,7 @@ def start_comfyui(asyncio_loop=None):
         hook_breaker_ac10a0.restore_functions()
 
     cuda_malloc_warning()
+    setup_database()
 
     prompt_server.add_routes()
     hijack_progress(prompt_server)
@@ -300,6 +309,9 @@ if __name__ == "__main__":
     logging.info("Python version: {}".format(sys.version))
     logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
 
+    if sys.version_info.major == 3 and sys.version_info.minor < 10:
+        logging.warning("WARNING: You are using a Python version older than 3.10; please upgrade to a newer one. 3.12 and above is recommended.")
+
     event_loop, _, start_all_func = start_comfyui()
     try:
         x = start_all_func()
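get_input_subfolders feeds the folder dropdown of the new Load Image Dataset from Folder node; it walks the input directory and returns slash-normalized relative paths, excluding the root. A standalone sketch of the same walk (pure stdlib, directory name hypothetical):

```python
import os

def list_subfolders(input_dir):
    folders = []
    if not os.path.exists(input_dir):
        return []
    for root, dirs, _ in os.walk(input_dir):
        rel_path = os.path.relpath(root, input_dir)
        if rel_path != ".":                        # skip the root itself
            folders.append(rel_path.replace(os.sep, "/"))
    return sorted(folders)

# e.g. input/dataset/train, input/dataset/val ->
# ["dataset", "dataset/train", "dataset/val"]
print(list_subfolders("input"))
```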
diff --git a/nodes.py b/nodes.py
index 9278ffda8..c8d0cacb5 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2067,6 +2067,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "ImageQuantize": "Image Quantize",
     "ImageSharpen": "Image Sharpen",
     "ImageScaleToTotalPixels": "Scale Image to Total Pixels",
+    "GetImageSize": "Get Image Size",
     # _for_testing
     "VAEDecodeTiled": "VAE Decode (Tiled)",
     "VAEEncodeTiled": "VAE Encode (Tiled)",
@@ -2124,6 +2125,25 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
 
         LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir)
 
+        try:
+            from comfy_config import config_parser
+
+            project_config = config_parser.extract_node_configuration(module_path)
+
+            web_dir_name = project_config.tool_comfy.web
+
+            if web_dir_name:
+                web_dir_path = os.path.join(module_path, web_dir_name)
+
+                if os.path.isdir(web_dir_path):
+                    project_name = project_config.project.name
+
+                    EXTENSION_WEB_DIRS[project_name] = web_dir_path
+
+                    logging.info("Automatically registered web folder {} for {}".format(web_dir_name, project_name))
+        except Exception as e:
+            logging.warning(f"Unable to parse pyproject.toml due to the missing dependency 'pydantic-settings'; please run 'pip install -r requirements.txt': {e}")
+
         if hasattr(module, "WEB_DIRECTORY") and getattr(module, "WEB_DIRECTORY") is not None:
             web_dir = os.path.abspath(os.path.join(module_dir, getattr(module, "WEB_DIRECTORY")))
             if os.path.isdir(web_dir):
@@ -2211,6 +2231,7 @@ def init_builtin_extra_nodes():
         "nodes_model_downscale.py",
         "nodes_images.py",
         "nodes_video_model.py",
+        "nodes_train.py",
         "nodes_sag.py",
         "nodes_perpneg.py",
         "nodes_stable3d.py",
diff --git a/pyproject.toml b/pyproject.toml
index 28a6158e0..c572ad4c6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.39"
+version = "0.3.41"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
diff --git a/requirements.txt b/requirements.txt
index 60174ff57..336ec9d57 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
-comfyui-frontend-package==1.21.3
-comfyui-workflow-templates==0.1.23
-comfyui-embedded-docs==0.2.0
+comfyui-frontend-package==1.21.7
+comfyui-workflow-templates==0.1.28
+comfyui-embedded-docs==0.2.2
 torch
 torchsde
 torchvision
@@ -18,6 +18,8 @@ Pillow
 scipy
 tqdm
 psutil
+alembic
+SQLAlchemy
 
 #non essential dependencies:
 kornia>=0.7.1
@@ -25,3 +27,4 @@ spandrel
 soundfile
 av>=14.2.0
 pydantic~=2.0
+pydantic-settings~=2.0
diff --git a/server.py b/server.py
index 6e283fe31..878b5eeb1 100644
--- a/server.py
+++ b/server.py
@@ -390,7 +390,7 @@ class PromptServer():
         async def view_image(request):
             if "filename" in request.rel_url.query:
                 filename = request.rel_url.query["filename"]
-                filename,output_dir = folder_paths.annotated_filepath(filename)
+                filename, output_dir = folder_paths.annotated_filepath(filename)
 
                 if not filename:
                     return web.Response(status=400)
@@ -476,9 +476,8 @@ class PromptServer():
                 # Get content type from mimetype, defaulting to 'application/octet-stream'
                 content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
 
-                # For security, force certain extensions to download instead of display
-                file_extension = os.path.splitext(filename)[1].lower()
-                if file_extension in {'.html', '.htm', '.js', '.css'}:
+                # For security, force certain mimetypes to download instead of display
+                if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
                     content_type = 'application/octet-stream'  # Forces download
 
                 return web.FileResponse(
@@ -789,7 +788,7 @@ class PromptServer():
             if hasattr(Image, 'Resampling'):
                 resampling = Image.Resampling.BILINEAR
             else:
-                resampling = Image.ANTIALIAS
+                resampling = Image.Resampling.LANCZOS
             image = ImageOps.contain(image, (max_size, max_size), resampling)
             type_num = 1
diff --git a/tests-unit/comfy_extras_test/image_stitch_test.py b/tests-unit/comfy_extras_test/image_stitch_test.py
index fbaef756c..b5a0f022c 100644
--- a/tests-unit/comfy_extras_test/image_stitch_test.py
+++ b/tests-unit/comfy_extras_test/image_stitch_test.py
@@ -5,7 +5,10 @@ from unittest.mock import patch, MagicMock
 mock_nodes = MagicMock()
 mock_nodes.MAX_RESOLUTION = 16384
 
-with patch.dict('sys.modules', {'nodes': mock_nodes}):
+# Mock the server module so importing nodes_images does not pull in PromptServer
+mock_server = MagicMock()
+
+with patch.dict('sys.modules', {'nodes': mock_nodes, 'server': mock_server}):
     from comfy_extras.nodes_images import ImageStitch
diff --git a/tests-unit/folder_paths_test/misc_test.py b/tests-unit/folder_paths_test/misc_test.py
new file mode 100644
index 000000000..fcf667453
--- /dev/null
+++ b/tests-unit/folder_paths_test/misc_test.py
@@ -0,0 +1,51 @@
+import pytest
+import os
+import tempfile
+from folder_paths import get_input_subfolders, set_input_directory
+
+@pytest.fixture(scope="module")
+def mock_folder_structure():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # Create a nested folder structure
+        folders = [
+            "folder1",
+            "folder1/subfolder1",
+            "folder1/subfolder2",
+            "folder2",
+            "folder2/deep",
+            "folder2/deep/nested",
+            "empty_folder"
+        ]
+
+        # Create the folders
+        for folder in folders:
+            os.makedirs(os.path.join(temp_dir, folder))
+
+        # Add some files to test they're not included
+        with open(os.path.join(temp_dir, "root_file.txt"), "w") as f:
+            f.write("test")
+        with open(os.path.join(temp_dir, "folder1", "test.txt"), "w") as f:
+            f.write("test")
+
+        set_input_directory(temp_dir)
+        yield temp_dir
+
+
+def test_gets_all_folders(mock_folder_structure):
+    folders = get_input_subfolders()
+    expected = ["folder1", "folder1/subfolder1", "folder1/subfolder2",
+                "folder2", "folder2/deep", "folder2/deep/nested", "empty_folder"]
+    assert sorted(folders) == sorted(expected)
+
+
+def test_handles_nonexistent_input_directory():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        nonexistent = os.path.join(temp_dir, "nonexistent")
+        set_input_directory(nonexistent)
+        assert get_input_subfolders() == []
+
+
+def test_empty_input_directory():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        set_input_directory(temp_dir)
+        assert get_input_subfolders() == []  # Empty since we don't include root
diff --git a/utils/install_util.py b/utils/install_util.py
new file mode 100644
index 000000000..0f59bcf91
--- /dev/null
+++ b/utils/install_util.py
@@ -0,0 +1,18 @@
+from pathlib import Path
+import sys
+
+# The path to the requirements.txt file
+requirements_path = Path(__file__).parents[1] / "requirements.txt"
+
+
+def get_missing_requirements_message():
+    """Return the warning message to display when a package is missing."""
+
+    extra = ""
+    if sys.flags.no_user_site:
+        extra = "-s "
+    return f"""
+Please install the updated requirements.txt file by running:
+{sys.executable} {extra}-m pip install -r {requirements_path}
+If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem.
+""".strip()
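get_missing_requirements_message composes the install hint from the running interpreter: sys.executable keeps the pip command tied to the active environment, and the -s flag is reproduced when Python was started with user site-packages disabled, as the portable build does. A quick way to see the rendered message, assuming the utils package is importable:

```python
from utils.install_util import get_missing_requirements_message

# Prints something like (paths depend on the local install):
#   Please install the updated requirements.txt file by running:
#   /usr/bin/python3 -m pip install -r /path/to/ComfyUI/requirements.txt
#   If you are on the portable package you can run: update\update_comfyui.bat to solve this problem.
print(get_missing_requirements_message())
```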