diff --git a/README.md b/README.md index 1ceaccb3c..0de4a6bb5 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ [![Website][website-shield]][website-url] [![Dynamic JSON Badge][discord-shield]][discord-url] +[![Twitter][twitter-shield]][twitter-url] [![Matrix][matrix-shield]][matrix-url]
[![][github-release-shield]][github-release-link] @@ -20,6 +21,8 @@ [discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total [discord-url]: https://www.comfy.org/discord +[twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI +[twitter-url]: https://x.com/ComfyUI [github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver [github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases @@ -62,12 +65,13 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/) - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/) - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/) + - [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/) - Video Models - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/) - [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/) - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/) - - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) + - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/) - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/) - Audio Models - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/) diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 000000000..12f18712f --- /dev/null +++ b/alembic.ini @@ -0,0 +1,84 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +# Use forward slashes (/) also on windows to provide an os agnostic path +script_location = alembic_db + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to alembic_db/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. 
+# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +# version_path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. +version_path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = sqlite:///user/comfyui.db + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME diff --git a/alembic_db/README.md b/alembic_db/README.md new file mode 100644 index 000000000..3b808c7ca --- /dev/null +++ b/alembic_db/README.md @@ -0,0 +1,4 @@ +## Generate new revision + +1. Update models in `/app/database/models.py` +2. Run `alembic revision --autogenerate -m "{your message}"` diff --git a/alembic_db/env.py b/alembic_db/env.py new file mode 100644 index 000000000..4d7770679 --- /dev/null +++ b/alembic_db/env.py @@ -0,0 +1,64 @@ +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + + +from app.database.models import Base +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + Calls to context.execute() here emit the given string to the + script output. + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + In this scenario we need to create an Engine + and associate a connection with the context. 
+ """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic_db/script.py.mako b/alembic_db/script.py.mako new file mode 100644 index 000000000..480b130d6 --- /dev/null +++ b/alembic_db/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/app/database/db.py b/app/database/db.py new file mode 100644 index 000000000..1de8b80ed --- /dev/null +++ b/app/database/db.py @@ -0,0 +1,112 @@ +import logging +import os +import shutil +from app.logger import log_startup_warning +from utils.install_util import get_missing_requirements_message +from comfy.cli_args import args + +_DB_AVAILABLE = False +Session = None + + +try: + from alembic import command + from alembic.config import Config + from alembic.runtime.migration import MigrationContext + from alembic.script import ScriptDirectory + from sqlalchemy import create_engine + from sqlalchemy.orm import sessionmaker + + _DB_AVAILABLE = True +except ImportError as e: + log_startup_warning( + f""" +------------------------------------------------------------------------ +Error importing dependencies: {e} +{get_missing_requirements_message()} +This error is happening because ComfyUI now uses a local sqlite database. 
+------------------------------------------------------------------------ +""".strip() + ) + + +def dependencies_available(): + """ + Temporary function to check if the dependencies are available + """ + return _DB_AVAILABLE + + +def can_create_session(): + """ + Temporary function to check if the database is available to create a session + During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created + """ + return dependencies_available() and Session is not None + + +def get_alembic_config(): + root_path = os.path.join(os.path.dirname(__file__), "../..") + config_path = os.path.abspath(os.path.join(root_path, "alembic.ini")) + scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db")) + + config = Config(config_path) + config.set_main_option("script_location", scripts_path) + config.set_main_option("sqlalchemy.url", args.database_url) + + return config + + +def get_db_path(): + url = args.database_url + if url.startswith("sqlite:///"): + return url.split("///")[1] + else: + raise ValueError(f"Unsupported database URL '{url}'.") + + +def init_db(): + db_url = args.database_url + logging.debug(f"Database URL: {db_url}") + db_path = get_db_path() + db_exists = os.path.exists(db_path) + + config = get_alembic_config() + + # Check if we need to upgrade + engine = create_engine(db_url) + conn = engine.connect() + + context = MigrationContext.configure(conn) + current_rev = context.get_current_revision() + + script = ScriptDirectory.from_config(config) + target_rev = script.get_current_head() + + if target_rev is None: + logging.warning("No target revision found.") + elif current_rev != target_rev: + # Backup the database pre upgrade + backup_path = db_path + ".bkp" + if db_exists: + shutil.copy(db_path, backup_path) + else: + backup_path = None + + try: + command.upgrade(config, target_rev) + logging.info(f"Database upgraded from {current_rev} to {target_rev}") + except Exception as e: + if backup_path: + # Restore the database from backup if upgrade fails + shutil.copy(backup_path, db_path) + os.remove(backup_path) + logging.exception("Error upgrading database: ") + raise e + + global Session + Session = sessionmaker(bind=engine) + + +def create_session(): + return Session() diff --git a/app/database/models.py b/app/database/models.py new file mode 100644 index 000000000..6facfb8f2 --- /dev/null +++ b/app/database/models.py @@ -0,0 +1,14 @@ +from sqlalchemy.orm import declarative_base + +Base = declarative_base() + + +def to_dict(obj): + fields = obj.__table__.columns.keys() + return { + field: (val.to_dict() if hasattr(val, "to_dict") else val) + for field in fields + if (val := getattr(obj, field)) + } + +# TODO: Define models here diff --git a/app/frontend_management.py b/app/frontend_management.py index d9ef8c921..001ebbecb 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -16,26 +16,17 @@ from importlib.metadata import version import requests from typing_extensions import NotRequired +from utils.install_util import get_missing_requirements_message, requirements_path + from comfy.cli_args import DEFAULT_VERSION_STRING import app.logger -# The path to the requirements.txt file -req_path = Path(__file__).parents[1] / "requirements.txt" - def frontend_install_warning_message(): - """The warning message to display when the frontend version is not up to date.""" - - extra = "" - if sys.flags.no_user_site: - extra = "-s " return f""" -Please install the updated requirements.txt file by 
running: -{sys.executable} {extra}-m pip install -r {req_path} +{get_missing_requirements_message()} This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead. - -If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem """.strip() @@ -48,7 +39,7 @@ def check_frontend_version(): try: frontend_version_str = version("comfyui-frontend-package") frontend_version = parse_version(frontend_version_str) - with open(req_path, "r", encoding="utf-8") as f: + with open(requirements_path, "r", encoding="utf-8") as f: required_frontend = parse_version(f.readline().split("=")[-1]) if frontend_version < required_frontend: app.logger.log_startup_warning( @@ -121,9 +112,22 @@ class FrontEndProvider: response.raise_for_status() # Raises an HTTPError if the response was an error return response.json() + @cached_property + def latest_prerelease(self) -> Release: + """Get the latest pre-release version - even if it's older than the latest release""" + release = [release for release in self.all_releases if release["prerelease"]] + + if not release: + raise ValueError("No pre-releases found") + + # GitHub returns releases in reverse chronological order, so first is latest + return release[0] + def get_release(self, version: str) -> Release: if version == "latest": return self.latest_release + elif version == "prerelease": + return self.latest_prerelease else: for release in self.all_releases: if release["tag_name"] in [version, f"v{version}"]: @@ -230,7 +234,7 @@ comfyui-workflow-templates is not installed. Raises: argparse.ArgumentTypeError: If the version string is invalid. """ - VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$" + VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+[-._a-zA-Z0-9]*|latest|prerelease)$" match_result = re.match(VERSION_PATTERN, value) if match_result is None: raise argparse.ArgumentTypeError(f"Invalid version string: {value}") diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 4fb675f99..741ecac3f 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -203,6 +203,11 @@ parser.add_argument( help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)", ) +database_default_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") +) +parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. 
for an in-memory database you can use 'sqlite:///:memory:'.") + if comfy.options.args_parsing: args = parser.parse_args() else: diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 470eb9fdb..071b98332 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -37,6 +37,8 @@ class IO(StrEnum): CONTROL_NET = "CONTROL_NET" VAE = "VAE" MODEL = "MODEL" + LORA_MODEL = "LORA_MODEL" + LOSS_MAP = "LOSS_MAP" CLIP_VISION = "CLIP_VISION" CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT" STYLE_MODEL = "STYLE_MODEL" diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 11483e21d..9a47b86f2 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -390,8 +390,9 @@ class ControlLora(ControlNet): pass for k in self.control_weights: - if k not in {"lora_controlnet"}: - comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device())) + if (k not in {"lora_controlnet"}): + if (k.endswith(".up") or k.endswith(".down") or k.endswith(".weight") or k.endswith(".bias")) and ("__" not in k): + comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device())) def copy(self): c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index fbdf6f554..8030048fc 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1,4 +1,5 @@ import math +from functools import partial from scipy import integrate import torch @@ -142,6 +143,33 @@ class BrownianTreeNoiseSampler: return self.tree(t0, t1) / (t1 - t0).abs().sqrt() +def sigma_to_half_log_snr(sigma, model_sampling): + """Convert sigma to half-logSNR log(alpha_t / sigma_t).""" + if isinstance(model_sampling, comfy.model_sampling.CONST): + # log((1 - t) / t) = log((1 - sigma) / sigma) + return sigma.logit().neg() + return sigma.log().neg() + + +def half_log_snr_to_sigma(half_log_snr, model_sampling): + """Convert half-logSNR log(alpha_t / sigma_t) to sigma.""" + if isinstance(model_sampling, comfy.model_sampling.CONST): + # 1 / (1 + exp(half_log_snr)) + return half_log_snr.neg().sigmoid() + return half_log_snr.neg().exp() + + +def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4): + """Adjust the first sigma to avoid invalid logSNR.""" + if len(sigmas) <= 1: + return sigmas + if isinstance(model_sampling, comfy.model_sampling.CONST): + if sigmas[0] >= 1: + sigmas = sigmas.clone() + sigmas[0] = model_sampling.percent_to_sigma(percent_offset) + return sigmas + + @torch.no_grad() def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """Implements Algorithm 2 (Euler steps) from Karras et al. 
(2022).""" @@ -753,6 +781,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No old_denoised = denoised return x + @torch.no_grad() def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): """DPM-Solver++(2M) SDE.""" @@ -768,9 +797,12 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + old_denoised = None - h_last = None - h = None + h, h_last = None, None for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -781,26 +813,29 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl x = denoised else: # DPM-Solver++(2M) SDE - t, s = -sigmas[i].log(), -sigmas[i + 1].log() - h = s - t - eta_h = eta * h + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s + h_eta = h * (eta + 1) - x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised + alpha_t = sigmas[i + 1] * lambda_t.exp() + + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised if old_denoised is not None: r = h_last / h if solver_type == 'heun': - x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised) + x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised) elif solver_type == 'midpoint': - x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised) + x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised) - if eta: - x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise + if eta > 0 and s_noise > 0: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise old_denoised = denoised h_last = h return x + @torch.no_grad() def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """DPM-Solver++(3M) SDE.""" @@ -814,6 +849,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + denoised_1, denoised_2 = None, None h, h_1, h_2 = None, None, None @@ -825,13 +864,16 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl # Denoising step x = denoised else: - t, s = -sigmas[i].log(), -sigmas[i + 1].log() - h = s - t + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s h_eta = h * (eta + 1) - x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised + alpha_t = sigmas[i + 1] * lambda_t.exp() + + x = sigmas[i + 1] / 
sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised if h_2 is not None: + # DPM-Solver++(3M) SDE r0 = h_1 / h r1 = h_2 / h d1_0 = (denoised - denoised_1) / r0 @@ -840,20 +882,22 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl d2 = (d1_0 - d1_1) / (r0 + r1) phi_2 = h_eta.neg().expm1() / h_eta + 1 phi_3 = phi_2 / h_eta - 0.5 - x = x + phi_2 * d1 - phi_3 * d2 + x = x + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) * d2 elif h_1 is not None: + # DPM-Solver++(2M) SDE r = h_1 / h d = (denoised - denoised_1) / r phi_2 = h_eta.neg().expm1() / h_eta + 1 - x = x + phi_2 * d + x = x + (alpha_t * phi_2) * d - if eta: + if eta > 0 and s_noise > 0: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise denoised_1, denoised_2 = denoised, denoised_1 h_1, h_2 = h, h_1 return x + @torch.no_grad() def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): if len(sigmas) <= 1: @@ -863,6 +907,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) + @torch.no_grad() def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): if len(sigmas) <= 1: @@ -872,6 +917,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) + @torch.no_grad() def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): if len(sigmas) <= 1: @@ -1449,12 +1495,12 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None old_denoised = denoised return x + @torch.no_grad() def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5): - ''' - SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2 - Arxiv: https://arxiv.org/abs/2305.14267 - ''' + """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2. 
+ arXiv: https://arxiv.org/abs/2305.14267 + """ extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler @@ -1462,6 +1508,11 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non inject_noise = eta > 0 and s_noise > 0 + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: @@ -1469,80 +1520,96 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non if sigmas[i + 1] == 0: x = denoised else: - t, t_next = -sigmas[i].log(), -sigmas[i + 1].log() - h = t_next - t + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s h_eta = h * (eta + 1) - s = t + r * h + lambda_s_1 = lambda_s + r * h fac = 1 / (2 * r) - sigma_s = s.neg().exp() + sigma_s_1 = sigma_fn(lambda_s_1) + + # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t) + alpha_s_1 = sigma_s_1 * lambda_s_1.exp() + alpha_t = sigmas[i + 1] * lambda_t.exp() coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1() if inject_noise: + # 0 < r < 1 noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt() - noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt() - noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1]) + noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt() + noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1]) # Step 1 - x_2 = (coeff_1 + 1) * x - coeff_1 * denoised - if inject_noise: - x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise - denoised_2 = model(x_2, sigma_s * s_in, **extra_args) - - # Step 2 - denoised_d = (1 - fac) * denoised + fac * denoised_2 - x = (coeff_2 + 1) * x - coeff_2 * denoised_d - if inject_noise: - x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise - return x - -@torch.no_grad() -def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3): - ''' - SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3 - Arxiv: https://arxiv.org/abs/2305.14267 - ''' - extra_args = {} if extra_args is None else extra_args - seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler - s_in = x.new_ones([x.shape[0]]) - - inject_noise = eta > 0 and s_noise > 0 - - for i in trange(len(sigmas) - 1, disable=disable): - denoised = model(x, sigmas[i] * s_in, **extra_args) - if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) - if sigmas[i + 1] == 0: - x = denoised - else: - t, t_next = -sigmas[i].log(), -sigmas[i + 1].log() - h = t_next - t - h_eta = h * (eta + 1) - s_1 = t + r_1 * h - s_2 = t + r_2 * h - sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp() - - coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1() - if inject_noise: - noise_coeff_1 = (-2 
* r_1 * h * eta).expm1().neg().sqrt() - noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt() - noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt() - noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1]) - - # Step 1 - x_2 = (coeff_1 + 1) * x - coeff_1 * denoised + x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised if inject_noise: x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) # Step 2 - x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised) + denoised_d = (1 - fac) * denoised + fac * denoised_2 + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d + if inject_noise: + x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise + return x + + +@torch.no_grad() +def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3): + """SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3. + arXiv: https://arxiv.org/abs/2305.14267 + """ + extra_args = {} if extra_args is None else extra_args + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + + inject_noise = eta > 0 and s_noise > 0 + + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + x = denoised + else: + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s + h_eta = h * (eta + 1) + lambda_s_1 = lambda_s + r_1 * h + lambda_s_2 = lambda_s + r_2 * h + sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2) + + # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t) + alpha_s_1 = sigma_s_1 * lambda_s_1.exp() + alpha_s_2 = sigma_s_2 * lambda_s_2.exp() + alpha_t = sigmas[i + 1] * lambda_t.exp() + + coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1() + if inject_noise: + # 0 < r_1 < r_2 < 1 + noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt() + noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt() + noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt() + noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1]) + + # Step 1 + x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised + if inject_noise: + x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise + denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) + + # Step 2 + x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * 
(coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised) if inject_noise: x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args) # Step 3 - x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised) + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised) if inject_noise: x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise return x diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py index a12f892d2..5c4356a3f 100644 --- a/comfy/ldm/cosmos/blocks.py +++ b/comfy/ldm/cosmos/blocks.py @@ -26,16 +26,6 @@ from torch import nn from comfy.ldm.modules.attention import optimized_attention -def apply_rotary_pos_emb( - t: torch.Tensor, - freqs: torch.Tensor, -) -> torch.Tensor: - t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float() - t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1] - t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t) - return t_out - - def get_normalization(name: str, channels: int, weight_args={}, operations=None): if name == "I": return nn.Identity() diff --git a/comfy/ldm/cosmos/position_embedding.py b/comfy/ldm/cosmos/position_embedding.py index 4d6a58dba..c925811d4 100644 --- a/comfy/ldm/cosmos/position_embedding.py +++ b/comfy/ldm/cosmos/position_embedding.py @@ -66,15 +66,16 @@ class VideoRopePosition3DEmb(VideoPositionEmb): h_extrapolation_ratio: float = 1.0, w_extrapolation_ratio: float = 1.0, t_extrapolation_ratio: float = 1.0, + enable_fps_modulation: bool = True, device=None, **kwargs, # used for compatibility with other positional embeddings; unused in this class ): del kwargs super().__init__() - self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device)) self.base_fps = base_fps self.max_h = len_h self.max_w = len_w + self.enable_fps_modulation = enable_fps_modulation dim = head_dim dim_h = dim // 6 * 2 @@ -132,21 +133,19 @@ class VideoRopePosition3DEmb(VideoPositionEmb): temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device)) B, T, H, W, _ = B_T_H_W_C + seq = torch.arange(max(H, W, T), dtype=torch.float, device=device) uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max()) assert ( uniform_fps or B == 1 or T == 1 ), "For video batch, batch size should be 1 for non-uniform fps. 
For image batch, T should be 1" - assert ( - H <= self.max_h and W <= self.max_w - ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})" - half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs) - half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs) + half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs) + half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs) # apply sequence scaling in temporal dimension - if fps is None: # image case - half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs) + if fps is None or self.enable_fps_modulation is False: # image case + half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs) else: - half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs) + half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs) half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1) half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1) diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py new file mode 100644 index 000000000..316117f77 --- /dev/null +++ b/comfy/ldm/cosmos/predict2.py @@ -0,0 +1,864 @@ +# original code from: https://github.com/nvidia-cosmos/cosmos-predict2 + +import torch +from torch import nn +from einops import rearrange +from einops.layers.torch import Rearrange +import logging +from typing import Callable, Optional, Tuple +import math + +from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis +from torchvision import transforms + +from comfy.ldm.modules.attention import optimized_attention + +def apply_rotary_pos_emb( + t: torch.Tensor, + freqs: torch.Tensor, +) -> torch.Tensor: + t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float() + t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1] + t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t) + return t_out + + +# ---------------------- Feed Forward Network ----------------------- +class GPT2FeedForward(nn.Module): + def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None: + super().__init__() + self.activation = nn.GELU() + self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype) + self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype) + + self._layer_id = None + self._dim = d_model + self._hidden_dim = d_ff + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.layer1(x) + + x = self.activation(x) + x = self.layer2(x) + return x + + +def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor: + """Computes multi-head attention using PyTorch's native implementation. + + This function provides a PyTorch backend alternative to Transformer Engine's attention operation. + It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product + attention, and rearranges the output back to the original format. 
+ + The input tensor names use the following dimension conventions: + + - B: batch size + - S: sequence length + - H: number of attention heads + - D: head dimension + + Args: + q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim) + k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim) + v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim) + + Returns: + Attention output tensor with shape (batch, seq_len, n_heads * head_dim) + """ + in_q_shape = q_B_S_H_D.shape + in_k_shape = k_B_S_H_D.shape + q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1]) + k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) + v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) + return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True) + + +class Attention(nn.Module): + """ + A flexible attention module supporting both self-attention and cross-attention mechanisms. + + This module implements a multi-head attention layer that can operate in either self-attention + or cross-attention mode. The mode is determined by whether a context dimension is provided. + The implementation uses scaled dot-product attention and supports optional bias terms and + dropout regularization. + + Args: + query_dim (int): The dimensionality of the query vectors. + context_dim (int, optional): The dimensionality of the context (key/value) vectors. + If None, the module operates in self-attention mode using query_dim. Default: None + n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8 + head_dim (int, optional): The dimension of each attention head. Default: 64 + dropout (float, optional): Dropout probability applied to the output. Default: 0.0 + qkv_format (str, optional): Format specification for QKV tensors. Default: "bshd" + backend (str, optional): Backend to use for the attention operation. Default: "transformer_engine" + + Examples: + >>> # Self-attention with 512 dimensions and 8 heads + >>> self_attn = Attention(query_dim=512) + >>> x = torch.randn(32, 16, 512) # (batch_size, seq_len, dim) + >>> out = self_attn(x) # (32, 16, 512) + + >>> # Cross-attention + >>> cross_attn = Attention(query_dim=512, context_dim=256) + >>> query = torch.randn(32, 16, 512) + >>> context = torch.randn(32, 8, 256) + >>> out = cross_attn(query, context) # (32, 16, 512) + """ + + def __init__( + self, + query_dim: int, + context_dim: Optional[int] = None, + n_heads: int = 8, + head_dim: int = 64, + dropout: float = 0.0, + device=None, + dtype=None, + operations=None, + ) -> None: + super().__init__() + logging.debug( + f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{n_heads} heads with a dimension of {head_dim}." 
+ ) + self.is_selfattn = context_dim is None # self attention + + context_dim = query_dim if context_dim is None else context_dim + inner_dim = head_dim * n_heads + + self.n_heads = n_heads + self.head_dim = head_dim + self.query_dim = query_dim + self.context_dim = context_dim + + self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.v_norm = nn.Identity() + + self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype) + self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity() + + self.attn_op = torch_attention_op + + self._query_dim = query_dim + self._context_dim = context_dim + self._inner_dim = inner_dim + + def compute_qkv( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + rope_emb: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q = self.q_proj(x) + context = x if context is None else context + k = self.k_proj(context) + v = self.v_proj(context) + q, k, v = map( + lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim), + (q, k, v), + ) + + def apply_norm_and_rotary_pos_emb( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q = self.q_norm(q) + k = self.k_norm(k) + v = self.v_norm(v) + if self.is_selfattn and rope_emb is not None: # only apply to self-attention! 
+ q = apply_rotary_pos_emb(q, rope_emb) + k = apply_rotary_pos_emb(k, rope_emb) + return q, k, v + + q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb) + + return q, k, v + + def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + result = self.attn_op(q, k, v) # [B, S, H, D] + return self.output_dropout(self.output_proj(result)) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + rope_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x (Tensor): The query tensor of shape [B, Mq, K] + context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None + """ + q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb) + return self.compute_attention(q, k, v) + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int): + super().__init__() + self.num_channels = num_channels + + def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor: + assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}" + timesteps = timesteps_B_T.flatten().float() + half_dim = self.num_channels // 2 + exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - 0.0) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + sin_emb = torch.sin(emb) + cos_emb = torch.cos(emb) + emb = torch.cat([cos_emb, sin_emb], dim=-1) + + return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1]) + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None): + super().__init__() + logging.debug( + f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility." + ) + self.in_dim = in_features + self.out_dim = out_features + self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype) + self.activation = nn.SiLU() + self.use_adaln_lora = use_adaln_lora + if use_adaln_lora: + self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype) + else: + self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype) + + def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + emb = self.linear_1(sample) + emb = self.activation(emb) + emb = self.linear_2(emb) + + if self.use_adaln_lora: + adaln_lora_B_T_3D = emb + emb_B_T_D = sample + else: + adaln_lora_B_T_3D = None + emb_B_T_D = emb + + return emb_B_T_D, adaln_lora_B_T_3D + + +class PatchEmbed(nn.Module): + """ + PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers, + depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions, + making it suitable for video and image processing tasks. It supports dividing the input into patches + and embedding each patch into a vector of size `out_channels`. + + Parameters: + - spatial_patch_size (int): The size of each spatial patch. + - temporal_patch_size (int): The size of each temporal patch. + - in_channels (int): Number of input channels. Default: 3. + - out_channels (int): The dimension of the embedding vector for each patch. Default: 768. 
+ - bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True. + """ + + def __init__( + self, + spatial_patch_size: int, + temporal_patch_size: int, + in_channels: int = 3, + out_channels: int = 768, + device=None, dtype=None, operations=None + ): + super().__init__() + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + + self.proj = nn.Sequential( + Rearrange( + "b c (t r) (h m) (w n) -> b t h w (c r m n)", + r=temporal_patch_size, + m=spatial_patch_size, + n=spatial_patch_size, + ), + operations.Linear( + in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype + ), + ) + self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the PatchEmbed module. + + Parameters: + - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where + B is the batch size, + C is the number of channels, + T is the temporal dimension, + H is the height, and + W is the width of the input. + + Returns: + - torch.Tensor: The embedded patches as a tensor, with shape b t h w c. + """ + assert x.dim() == 5 + _, _, T, H, W = x.shape + assert ( + H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0 + ), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}" + assert T % self.temporal_patch_size == 0 + x = self.proj(x) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of video DiT. + """ + + def __init__( + self, + hidden_size: int, + spatial_patch_size: int, + temporal_patch_size: int, + out_channels: int, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + device=None, dtype=None, operations=None + ): + super().__init__() + self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = operations.Linear( + hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype + ) + self.hidden_size = hidden_size + self.n_adaln_chunks = 2 + self.use_adaln_lora = use_adaln_lora + self.adaln_lora_dim = adaln_lora_dim + if use_adaln_lora: + self.adaln_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype), + ) + else: + self.adaln_modulation = nn.Sequential( + nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype) + ) + + def forward( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_T_D: torch.Tensor, + adaln_lora_B_T_3D: Optional[torch.Tensor] = None, + ): + if self.use_adaln_lora: + assert adaln_lora_B_T_3D is not None + shift_B_T_D, scale_B_T_D = ( + self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size] + ).chunk(2, dim=-1) + else: + shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1) + + shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d"), rearrange( + scale_B_T_D, "b t d -> b t 1 1 d" + ) + + def _fn( + _x_B_T_H_W_D: torch.Tensor, + _norm_layer: nn.Module, + _scale_B_T_1_1_D: torch.Tensor, + _shift_B_T_1_1_D: torch.Tensor, + ) -> torch.Tensor: + return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D + + x_B_T_H_W_D = _fn(x_B_T_H_W_D, 
self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D) + x_B_T_H_W_O = self.linear(x_B_T_H_W_D) + return x_B_T_H_W_O + + +class Block(nn.Module): + """ + A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation. + Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation. + + Parameters: + x_dim (int): Dimension of input features + context_dim (int): Dimension of context features for cross-attention + num_heads (int): Number of attention heads + mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0 + use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False + adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256 + + The block applies the following sequence: + 1. Self-attention with AdaLN modulation + 2. Cross-attention with AdaLN modulation + 3. MLP with AdaLN modulation + + Each component uses skip connections and layer normalization. + """ + + def __init__( + self, + x_dim: int, + context_dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + self.x_dim = x_dim + self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations) + + self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.cross_attn = Attention( + x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations + ) + + self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations) + + self.use_adaln_lora = use_adaln_lora + if self.use_adaln_lora: + self.adaln_modulation_self_attn = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), + ) + self.adaln_modulation_cross_attn = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), + ) + self.adaln_modulation_mlp = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), + ) + else: + self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) + self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) + self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) + + def forward( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_T_D: torch.Tensor, + crossattn_emb: torch.Tensor, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_T_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if extra_per_block_pos_emb is not None: + x_B_T_H_W_D = 
x_B_T_H_W_D + extra_per_block_pos_emb + + if self.use_adaln_lora: + shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = ( + self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = ( + self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = ( + self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + else: + shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn( + emb_B_T_D + ).chunk(3, dim=-1) + shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn( + emb_B_T_D + ).chunk(3, dim=-1) + shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1) + + # Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting + shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d") + scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d") + gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d") + + shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d") + scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d") + gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d") + + shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d") + scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d") + gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d") + + B, T, H, W, D = x_B_T_H_W_D.shape + + def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D): + return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D + + normalized_x_B_T_H_W_D = _fn( + x_B_T_H_W_D, + self.layer_norm_self_attn, + scale_self_attn_B_T_1_1_D, + shift_self_attn_B_T_1_1_D, + ) + result_B_T_H_W_D = rearrange( + self.self_attn( + # normalized_x_B_T_HW_D, + rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), + None, + rope_emb=rope_emb_L_1_1_D, + ), + "b (t h w) d -> b t h w d", + t=T, + h=H, + w=W, + ) + x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D + + def _x_fn( + _x_B_T_H_W_D: torch.Tensor, + layer_norm_cross_attn: Callable, + _scale_cross_attn_B_T_1_1_D: torch.Tensor, + _shift_cross_attn_B_T_1_1_D: torch.Tensor, + ) -> torch.Tensor: + _normalized_x_B_T_H_W_D = _fn( + _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D + ) + _result_B_T_H_W_D = rearrange( + self.cross_attn( + rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), + crossattn_emb, + rope_emb=rope_emb_L_1_1_D, + ), + "b (t h w) d -> b t h w d", + t=T, + h=H, + w=W, + ) + return _result_B_T_H_W_D + + result_B_T_H_W_D = _x_fn( + x_B_T_H_W_D, + self.layer_norm_cross_attn, + scale_cross_attn_B_T_1_1_D, + shift_cross_attn_B_T_1_1_D, + ) + x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D + + normalized_x_B_T_H_W_D = _fn( + x_B_T_H_W_D, + self.layer_norm_mlp, + scale_mlp_B_T_1_1_D, + shift_mlp_B_T_1_1_D, + ) + result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D) + x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D + return x_B_T_H_W_D + + +class MiniTrainDIT(nn.Module): + """ + A clean impl of DIT that can load and reproduce the training results of the original DIT 
model in~(cosmos 1) + A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. + + Args: + max_img_h (int): Maximum height of the input images. + max_img_w (int): Maximum width of the input images. + max_frames (int): Maximum number of frames in the video sequence. + in_channels (int): Number of input channels (e.g., RGB channels for color images). + out_channels (int): Number of output channels. + patch_spatial (tuple): Spatial resolution of patches for input processing. + patch_temporal (int): Temporal resolution of patches for input processing. + concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding. + model_channels (int): Base number of channels used throughout the model. + num_blocks (int): Number of transformer blocks. + num_heads (int): Number of heads in the multi-head attention layers. + mlp_ratio (float): Expansion ratio for MLP blocks. + crossattn_emb_channels (int): Number of embedding channels for cross-attention. + pos_emb_cls (str): Type of positional embeddings. + pos_emb_learnable (bool): Whether positional embeddings are learnable. + pos_emb_interpolation (str): Method for interpolating positional embeddings. + min_fps (int): Minimum frames per second. + max_fps (int): Maximum frames per second. + use_adaln_lora (bool): Whether to use AdaLN-LoRA. + adaln_lora_dim (int): Dimension for AdaLN-LoRA. + rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE. + rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE. + rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE. + extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings. + extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings. + extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings. + extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings. 
+ """ + + def __init__( + self, + max_img_h: int, + max_img_w: int, + max_frames: int, + in_channels: int, + out_channels: int, + patch_spatial: int, # tuple, + patch_temporal: int, + concat_padding_mask: bool = True, + # attention settings + model_channels: int = 768, + num_blocks: int = 10, + num_heads: int = 16, + mlp_ratio: float = 4.0, + # cross attention settings + crossattn_emb_channels: int = 1024, + # positional embedding settings + pos_emb_cls: str = "sincos", + pos_emb_learnable: bool = False, + pos_emb_interpolation: str = "crop", + min_fps: int = 1, + max_fps: int = 30, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + rope_h_extrapolation_ratio: float = 1.0, + rope_w_extrapolation_ratio: float = 1.0, + rope_t_extrapolation_ratio: float = 1.0, + extra_per_block_abs_pos_emb: bool = False, + extra_h_extrapolation_ratio: float = 1.0, + extra_w_extrapolation_ratio: float = 1.0, + extra_t_extrapolation_ratio: float = 1.0, + rope_enable_fps_modulation: bool = True, + image_model=None, + device=None, + dtype=None, + operations=None, + ) -> None: + super().__init__() + self.dtype = dtype + self.max_img_h = max_img_h + self.max_img_w = max_img_w + self.max_frames = max_frames + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + self.num_heads = num_heads + self.num_blocks = num_blocks + self.model_channels = model_channels + self.concat_padding_mask = concat_padding_mask + # positional embedding settings + self.pos_emb_cls = pos_emb_cls + self.pos_emb_learnable = pos_emb_learnable + self.pos_emb_interpolation = pos_emb_interpolation + self.min_fps = min_fps + self.max_fps = max_fps + self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio + self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio + self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio + self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb + self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio + self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio + self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio + self.rope_enable_fps_modulation = rope_enable_fps_modulation + + self.build_pos_embed(device=device, dtype=dtype) + self.use_adaln_lora = use_adaln_lora + self.adaln_lora_dim = adaln_lora_dim + self.t_embedder = nn.Sequential( + Timesteps(model_channels), + TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,), + ) + + in_channels = in_channels + 1 if concat_padding_mask else in_channels + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + device=device, dtype=dtype, operations=operations, + ) + + self.blocks = nn.ModuleList( + [ + Block( + x_dim=model_channels, + context_dim=crossattn_emb_channels, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + device=device, dtype=dtype, operations=operations, + ) + for _ in range(num_blocks) + ] + ) + + self.final_layer = FinalLayer( + hidden_size=self.model_channels, + spatial_patch_size=self.patch_spatial, + temporal_patch_size=self.patch_temporal, + out_channels=self.out_channels, + use_adaln_lora=self.use_adaln_lora, + adaln_lora_dim=self.adaln_lora_dim, + device=device, dtype=dtype, operations=operations, + ) + + self.t_embedding_norm = 
operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype) + + def build_pos_embed(self, device=None, dtype=None) -> None: + if self.pos_emb_cls == "rope3d": + cls_type = VideoRopePosition3DEmb + else: + raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") + + logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") + kwargs = dict( + model_channels=self.model_channels, + len_h=self.max_img_h // self.patch_spatial, + len_w=self.max_img_w // self.patch_spatial, + len_t=self.max_frames // self.patch_temporal, + max_fps=self.max_fps, + min_fps=self.min_fps, + is_learnable=self.pos_emb_learnable, + interpolation=self.pos_emb_interpolation, + head_dim=self.model_channels // self.num_heads, + h_extrapolation_ratio=self.rope_h_extrapolation_ratio, + w_extrapolation_ratio=self.rope_w_extrapolation_ratio, + t_extrapolation_ratio=self.rope_t_extrapolation_ratio, + enable_fps_modulation=self.rope_enable_fps_modulation, + device=device, + ) + self.pos_embedder = cls_type( + **kwargs, # type: ignore + ) + + if self.extra_per_block_abs_pos_emb: + kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio + kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio + kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio + kwargs["device"] = device + kwargs["dtype"] = dtype + self.extra_pos_embedder = LearnablePosEmbAxis( + **kwargs, # type: ignore + ) + + def prepare_embedded_sequence( + self, + x_B_C_T_H_W: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. + + Args: + x_B_C_T_H_W (torch.Tensor): Input video tensor of shape (B, C, T, H, W). + fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. + If None, a default value (`self.base_fps`) will be used. + padding_mask (Optional[torch.Tensor]): Optional padding mask; when `concat_padding_mask` is True it is resized and concatenated to the input as an extra channel. + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + - A tensor of shape (B, T, H, W, D) with the embedded sequence. + - An optional positional embedding tensor, returned only if the positional embedding class + (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. + - An optional extra per-block absolute positional embedding tensor, returned only if `self.extra_per_block_abs_pos_emb` is True. Otherwise, None. + + Notes: + - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. + - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. + - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using + the `self.pos_embedder` with the shape [T, H, W]. + - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the + `self.pos_embedder` with the fps tensor. + - Otherwise, the positional embeddings are generated without considering fps.
+ """ + if self.concat_padding_mask: + if padding_mask is None: + padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device) + else: + padding_mask = transforms.functional.resize( + padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 + ) + x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) + + if self.extra_per_block_abs_pos_emb: + extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype) + else: + extra_pos_emb = None + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D] + + return x_B_T_H_W_D, None, extra_pos_emb + + def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor: + x_B_C_Tt_Hp_Wp = rearrange( + x_B_T_H_W_M, + "B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)", + p1=self.patch_spatial, + p2=self.patch_spatial, + t=self.patch_temporal, + ) + return x_B_C_Tt_Hp_Wp + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + x_B_C_T_H_W = x + timesteps_B_T = timesteps + crossattn_emb = context + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + """ + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x_B_C_T_H_W, + fps=fps, + padding_mask=padding_mask, + ) + + if timesteps_B_T.ndim == 1: + timesteps_B_T = timesteps_B_T.unsqueeze(1) + t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype)) + t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D) + + # for logging purpose + affline_scale_log_info = {} + affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach() + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = t_embedding_B_T_D + self.crossattn_emb = crossattn_emb + + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + assert ( + x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape + ), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}" + + block_kwargs = { + "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0), + "adaln_lora_B_T_3D": adaln_lora_B_T_3D, + "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + for block in self.blocks: + x_B_T_H_W_D = block( + x_B_T_H_W_D, + t_embedding_B_T_D, + crossattn_emb, + **block_kwargs, + ) + + x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D) + x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O) + return x_B_C_Tt_Hp_Wp diff --git a/comfy/ldm/flux/controlnet.py b/comfy/ldm/flux/controlnet.py index 5322c4891..dbd2a47c0 100644 --- a/comfy/ldm/flux/controlnet.py +++ b/comfy/ldm/flux/controlnet.py @@ -121,6 +121,9 @@ class ControlNetFlux(Flux): if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") + if y is None: + y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, 
dtype=img.dtype) + # running on sequences img img = self.img_in(img) @@ -174,7 +177,7 @@ class ControlNetFlux(Flux): out["output"] = out_output[:self.main_model_single] return out - def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs): + def forward(self, x, timesteps, context, y=None, guidance=None, hint=None, **kwargs): patch_size = 2 if self.latent_input: hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size)) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index ef4ba4106..846703d52 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -101,6 +101,10 @@ class Flux(nn.Module): transformer_options={}, attn_mask: Tensor = None, ) -> Tensor: + + if y is None: + y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) + patches_replace = transformer_options.get("patches_replace", {}) if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -155,6 +159,9 @@ class Flux(nn.Module): if add is not None: img += add + if img.dtype == torch.float16: + img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504) + img = torch.cat((txt, img), 1) for i, block in enumerate(self.single_blocks): @@ -188,7 +195,7 @@ class Flux(nn.Module): img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img - def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs): + def forward(self, x, timestep, context, y=None, guidance=None, control=None, transformer_options={}, **kwargs): bs, c, h, w = x.shape patch_size = self.patch_size x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 056e101a4..ad9a7daea 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -261,8 +261,8 @@ class CrossAttention(nn.Module): self.heads = heads self.dim_head = dim_head - self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device) - self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device) + self.q_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device) + self.k_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device) self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device) self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 2cb77d85d..35d2270ee 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -753,7 +753,7 @@ class BasicTransformerBlock(nn.Module): for p in patch: n = p(n, extra_options) - x += n + x = n + x if "middle_patch" in transformer_patches: patch = transformer_patches["middle_patch"] for p in patch: @@ -793,12 +793,12 @@ class BasicTransformerBlock(nn.Module): for p in patch: n = p(n, extra_options) - x += n + x = n + x if self.is_res: x_skip = x x = self.ff(self.norm3(x)) if self.is_res: - x += x_skip + x = x_skip + x return x diff --git a/comfy/model_base.py b/comfy/model_base.py index e0c2bcaa8..cb7689e84 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -34,6 +34,7 @@ import comfy.ldm.flux.model import comfy.ldm.lightricks.model import comfy.ldm.hunyuan_video.model import comfy.ldm.cosmos.model +import comfy.ldm.cosmos.predict2 import comfy.ldm.lumina.model import comfy.ldm.wan.model import 
comfy.ldm.hunyuan3d.model @@ -48,6 +49,7 @@ import comfy.ops from enum import Enum from . import utils import comfy.latent_formats +import comfy.model_sampling import math from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -63,38 +65,39 @@ class ModelType(Enum): V_PREDICTION_CONTINUOUS = 7 FLUX = 8 IMG_TO_IMG = 9 - - -from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV + FLOW_COSMOS = 10 def model_sampling(model_config, model_type): - s = ModelSamplingDiscrete + s = comfy.model_sampling.ModelSamplingDiscrete if model_type == ModelType.EPS: - c = EPS + c = comfy.model_sampling.EPS elif model_type == ModelType.V_PREDICTION: - c = V_PREDICTION + c = comfy.model_sampling.V_PREDICTION elif model_type == ModelType.V_PREDICTION_EDM: - c = V_PREDICTION - s = ModelSamplingContinuousEDM + c = comfy.model_sampling.V_PREDICTION + s = comfy.model_sampling.ModelSamplingContinuousEDM elif model_type == ModelType.FLOW: c = comfy.model_sampling.CONST s = comfy.model_sampling.ModelSamplingDiscreteFlow elif model_type == ModelType.STABLE_CASCADE: - c = EPS - s = StableCascadeSampling + c = comfy.model_sampling.EPS + s = comfy.model_sampling.StableCascadeSampling elif model_type == ModelType.EDM: - c = EDM - s = ModelSamplingContinuousEDM + c = comfy.model_sampling.EDM + s = comfy.model_sampling.ModelSamplingContinuousEDM elif model_type == ModelType.V_PREDICTION_CONTINUOUS: - c = V_PREDICTION - s = ModelSamplingContinuousV + c = comfy.model_sampling.V_PREDICTION + s = comfy.model_sampling.ModelSamplingContinuousV elif model_type == ModelType.FLUX: c = comfy.model_sampling.CONST s = comfy.model_sampling.ModelSamplingFlux elif model_type == ModelType.IMG_TO_IMG: c = comfy.model_sampling.IMG_TO_IMG + elif model_type == ModelType.FLOW_COSMOS: + c = comfy.model_sampling.COSMOS_RFLOW + s = comfy.model_sampling.ModelSamplingCosmosRFlow class ModelSampling(s, c): pass @@ -998,6 +1001,43 @@ class CosmosVideo(BaseModel): latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image) return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5) +class CosmosPredict2(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW_COSMOS, image_to_video=False, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.predict2.MiniTrainDIT) + self.image_to_video = image_to_video + if self.image_to_video: + self.concat_keys = ("mask_inverted",) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if denoise_mask is not None: + out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask) + + out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None)) + return out + + def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): + if denoise_mask is None: + return timestep + condition_video_mask_B_1_T_1_1 = denoise_mask.mean(dim=[1, 3, 4], keepdim=True) + c_noise_B_1_T_1_1 = 0.0 * (1.0 - condition_video_mask_B_1_T_1_1) + timestep.reshape(timestep.shape[0], 1, 1, 1, 1) * condition_video_mask_B_1_T_1_1 + out = c_noise_B_1_T_1_1.squeeze(dim=[1, 3, 4]) + return out + + def scale_latent_inpaint(self, sigma, 
noise, latent_image, **kwargs): + sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)) + sigma_noise_augmentation = 0 #TODO + if sigma_noise_augmentation != 0: + latent_image = latent_image + noise + latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image) + sigma = (sigma / (sigma + 1)) + return latent_image / (1.0 - sigma) + class Lumina2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 74f539598..4aa90d3b6 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -407,6 +407,58 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["text_emb_dim"] = 2048 return dit_config + if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2 + dit_config = {} + dit_config["image_model"] = "cosmos_predict2" + dit_config["max_img_h"] = 240 + dit_config["max_img_w"] = 240 + dit_config["max_frames"] = 128 + concat_padding_mask = True + dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask) + dit_config["out_channels"] = 16 + dit_config["patch_spatial"] = 2 + dit_config["patch_temporal"] = 1 + dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0] + dit_config["concat_padding_mask"] = concat_padding_mask + dit_config["crossattn_emb_channels"] = 1024 + dit_config["pos_emb_cls"] = "rope3d" + dit_config["pos_emb_learnable"] = True + dit_config["pos_emb_interpolation"] = "crop" + dit_config["min_fps"] = 1 + dit_config["max_fps"] = 30 + + dit_config["use_adaln_lora"] = True + dit_config["adaln_lora_dim"] = 256 + if dit_config["model_channels"] == 2048: + dit_config["num_blocks"] = 28 + dit_config["num_heads"] = 16 + elif dit_config["model_channels"] == 5120: + dit_config["num_blocks"] = 36 + dit_config["num_heads"] = 40 + + if dit_config["in_channels"] == 16: + dit_config["extra_per_block_abs_pos_emb"] = False + dit_config["rope_h_extrapolation_ratio"] = 4.0 + dit_config["rope_w_extrapolation_ratio"] = 4.0 + dit_config["rope_t_extrapolation_ratio"] = 1.0 + elif dit_config["in_channels"] == 17: # img to video + if dit_config["model_channels"] == 2048: + dit_config["extra_per_block_abs_pos_emb"] = False + dit_config["rope_h_extrapolation_ratio"] = 3.0 + dit_config["rope_w_extrapolation_ratio"] = 3.0 + dit_config["rope_t_extrapolation_ratio"] = 1.0 + elif dit_config["model_channels"] == 5120: + dit_config["rope_h_extrapolation_ratio"] = 2.0 + dit_config["rope_w_extrapolation_ratio"] = 2.0 + dit_config["rope_t_extrapolation_ratio"] = 0.8333333333333334 + + dit_config["extra_h_extrapolation_ratio"] = 1.0 + dit_config["extra_w_extrapolation_ratio"] = 1.0 + dit_config["extra_t_extrapolation_ratio"] = 1.0 + dit_config["rope_enable_fps_modulation"] = False + + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None diff --git a/comfy/model_management.py b/comfy/model_management.py index 8ae5a5abb..054291432 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -295,6 +295,7 @@ except: pass +SUPPORT_FP8_OPS = args.supports_fp8_compute try: if is_amd(): try: @@ -305,9 +306,13 @@ try: logging.info("AMD arch: {}".format(arch)) 
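For illustration, a small sketch (not part of the patch) of the arithmetic behind the Cosmos Predict2 detection block above; the weight shape is assumed for a hypothetical 2B text-to-video checkpoint:

    # x_embedder.proj.1.weight has shape [model_channels, in_features], where
    # in_features = (latent_channels + padding_mask_channel) * patch_spatial**2 * patch_temporal
    weight_shape = (2048, 68)                 # assumed: 68 = (16 + 1) * 2 * 2 * 1
    model_channels = weight_shape[0]          # 2048 -> num_blocks = 28, num_heads = 16
    in_channels = weight_shape[1] // 4 - 1    # 16 -> text-to-image/video config; 17 would select the image-to-video config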
logging.info("ROCm version: {}".format(rocm_version)) if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: - if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much - if any((a in arch) for a in ["gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches + if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much + if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx1201 and gfx950 ENABLE_PYTORCH_ATTENTION = True + if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4): + if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches + SUPPORT_FP8_OPS = True + except: pass @@ -328,7 +333,7 @@ except: pass try: - if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5: + if torch_version_numeric >= (2, 5): torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True) except: logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp") @@ -1047,7 +1052,7 @@ def pytorch_attention_flash_attention(): global ENABLE_PYTORCH_ATTENTION if ENABLE_PYTORCH_ATTENTION: #TODO: more reliable way of checking for flash attention? - if is_nvidia(): #pytorch flash attention only works on Nvidia + if is_nvidia(): return True if is_intel_xpu(): return True @@ -1063,7 +1068,7 @@ def force_upcast_attention_dtype(): upcast = args.force_upcast_attention macos_version = mac_version() - if macos_version is not None and ((14, 5) <= macos_version < (16,)): # black image bug on recent versions of macOS + if macos_version is not None and ((14, 5) <= macos_version): # black image bug on recent versions of macOS, I don't think it's ever getting fixed upcast = True if upcast: @@ -1262,7 +1267,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma return False def supports_fp8_compute(device=None): - if args.supports_fp8_compute: + if SUPPORT_FP8_OPS: return True if not is_nvidia(): @@ -1276,11 +1281,11 @@ def supports_fp8_compute(device=None): if props.minor < 9: return False - if torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 3): + if torch_version_numeric < (2, 3): return False if WINDOWS: - if (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 4): + if torch_version_numeric < (2, 4): return False return True diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b7cb12dfc..b1d6d4395 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -17,23 +17,26 @@ """ from __future__ import annotations -from typing import Optional, Callable -import torch + +import collections import copy import inspect import logging -import uuid -import collections import math +import uuid +from typing import Callable, Optional + +import torch -import comfy.utils import comfy.float -import comfy.model_management -import comfy.lora import comfy.hooks +import comfy.lora +import comfy.model_management import comfy.patcher_extension -from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection +import comfy.utils from comfy.comfy_types import UnetWrapperFunction +from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP + def string_to_seed(data): crc = 0xFFFFFFFF diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 7e7291476..b240b7f29 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -77,6 +77,25 @@ class IMG_TO_IMG(X0): 
def calculate_input(self, sigma, noise): return noise +class COSMOS_RFLOW: + def calculate_input(self, sigma, noise): + sigma = (sigma / (sigma + 1)) + sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + return noise * (1.0 - sigma) + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = (sigma / (sigma + 1)) + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input * (1.0 - sigma) - model_output * sigma + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + noise = noise * sigma + noise += latent_image + return noise + + def inverse_noise_scaling(self, sigma, latent): + return latent class ModelSamplingDiscrete(torch.nn.Module): def __init__(self, model_config=None, zsnr=None): @@ -350,3 +369,15 @@ class ModelSamplingFlux(torch.nn.Module): if percent >= 1.0: return 0.0 return flux_time_shift(self.shift, 1.0, 1.0 - percent) + + +class ModelSamplingCosmosRFlow(ModelSamplingContinuousEDM): + def timestep(self, sigma): + return sigma / (sigma + 1) + + def sigma(self, timestep): + sigma_max = self.sigma_max + if timestep >= (sigma_max / (sigma_max + 1)): + return sigma_max + + return timestep / (1 - timestep) diff --git a/comfy/sd.py b/comfy/sd.py index e98a3aa87..cd13ab5f0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1081,7 +1081,28 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return (model_patcher, clip, vae, clipvision) -def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format +def load_diffusion_model_state_dict(sd, model_options={}): + """ + Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats. + + Args: + sd (dict): State dictionary containing model weights and configuration + model_options (dict, optional): Additional options for model loading. Supports: + - dtype: Override model data type + - custom_operations: Custom model operations + - fp8_optimizations: Enable FP8 optimizations + + Returns: + ModelPatcher: A wrapped model instance that handles device management and weight loading. + Returns None if the model configuration cannot be detected. + + The function: + 1. Detects and handles different model formats (regular, diffusers, mmdit) + 2. Configures model dtype based on parameters and device capabilities + 3. Handles weight conversion and device placement + 4. Manages model optimization settings + 5. 
Loads weights and returns a device-managed model instance + """ dtype = model_options.get("dtype", None) #Allow loading unets from checkpoint files diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index ac61babe9..1b69a4103 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -462,7 +462,7 @@ class SDTokenizer: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length) - self.min_length = min_length + self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length) self.end_token = None self.min_padding = min_padding diff --git a/comfy/supported_models.py b/comfy/supported_models.py index efe2e6b8f..19f25e337 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -908,6 +908,48 @@ class CosmosI2V(CosmosT2V): out = model_base.CosmosVideo(self, image_to_video=True, device=device) return out +class CosmosT2IPredict2(supported_models_base.BASE): + unet_config = { + "image_model": "cosmos_predict2", + "in_channels": 16, + } + + sampling_settings = { + "sigma_data": 1.0, + "sigma_max": 80.0, + "sigma_min": 0.002, + } + + unet_extra_config = {} + latent_format = latent_formats.Wan21 + + memory_usage_factor = 1.0 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.CosmosPredict2(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect)) + +class CosmosI2VPredict2(CosmosT2IPredict2): + unet_config = { + "image_model": "cosmos_predict2", + "in_channels": 17, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.CosmosPredict2(self, image_to_video=True, device=device) + return out + class Lumina2(supported_models_base.BASE): unet_config = { "image_model": "lumina2", @@ -1139,6 +1181,6 @@ class ACEStep(supported_models_base.BASE): def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, 
Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep] models += [SVD_img2vid] diff --git a/comfy/weight_adapter/__init__.py b/comfy/weight_adapter/__init__.py index d2a1d0151..560b82be3 100644 --- a/comfy/weight_adapter/__init__.py +++ b/comfy/weight_adapter/__init__.py @@ -1,4 +1,4 @@ -from .base import WeightAdapterBase +from .base import WeightAdapterBase, WeightAdapterTrainBase from .lora import LoRAAdapter from .loha import LoHaAdapter from .lokr import LoKrAdapter @@ -15,3 +15,9 @@ adapters: list[type[WeightAdapterBase]] = [ OFTAdapter, BOFTAdapter, ] + +__all__ = [ + "WeightAdapterBase", + "WeightAdapterTrainBase", + "adapters" +] + [a.__name__ for a in adapters] diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py index 29873519d..b5c7db423 100644 --- a/comfy/weight_adapter/base.py +++ b/comfy/weight_adapter/base.py @@ -12,12 +12,20 @@ class WeightAdapterBase: weights: list[torch.Tensor] @classmethod - def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]: + def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]: raise NotImplementedError def to_train(self) -> "WeightAdapterTrainBase": raise NotImplementedError + @classmethod + def create_train(cls, weight, *args) -> "WeightAdapterTrainBase": + """ + weight: The original weight tensor to be modified. + *args: Additional arguments for configuration, such as rank, alpha etc. + """ + raise NotImplementedError + def calculate_weight( self, weight, @@ -33,10 +41,22 @@ class WeightAdapterBase: class WeightAdapterTrainBase(nn.Module): + # We follow the scheme of PR #7032 def __init__(self): super().__init__() - # [TODO] Collaborate with LoRA training PR #7032 + def __call__(self, w): + """ + w: The original weight tensor to be modified. 
+ """ + raise NotImplementedError + + def passive_memory_usage(self): + raise NotImplementedError("passive_memory_usage is not implemented") + + def move_to(self, device): + self.to(device) + return self.passive_memory_usage() def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function): @@ -102,3 +122,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten padded_tensor[new_slices] = tensor[orig_slices] return padded_tensor + + +def tucker_weight_from_conv(up, down, mid): + up = up.reshape(up.size(0), up.size(1)) + down = down.reshape(down.size(0), down.size(1)) + return torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down) + + +def tucker_weight(wa, wb, t): + temp = torch.einsum("i j ..., j r -> i r ...", t, wb) + return torch.einsum("i j ..., i r -> r j ...", temp, wa) diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index b2e623924..729dbd9e6 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -3,7 +3,56 @@ from typing import Optional import torch import comfy.model_management -from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape +from .base import ( + WeightAdapterBase, + WeightAdapterTrainBase, + weight_decompose, + pad_tensor_to_shape, + tucker_weight_from_conv, +) + + +class LoraDiff(WeightAdapterTrainBase): + def __init__(self, weights): + super().__init__() + mat1, mat2, alpha, mid, dora_scale, reshape = weights + out_dim, rank = mat1.shape[0], mat1.shape[1] + rank, in_dim = mat2.shape[0], mat2.shape[1] + if mid is not None: + convdim = mid.ndim - 2 + layer = ( + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d + )[convdim] + else: + layer = torch.nn.Linear + self.lora_up = layer(rank, out_dim, bias=False) + self.lora_down = layer(in_dim, rank, bias=False) + self.lora_up.weight.data.copy_(mat1) + self.lora_down.weight.data.copy_(mat2) + if mid is not None: + self.lora_mid = layer(mid, rank, bias=False) + self.lora_mid.weight.data.copy_(mid) + else: + self.lora_mid = None + self.rank = rank + self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False) + + def __call__(self, w): + org_dtype = w.dtype + if self.lora_mid is None: + diff = self.lora_up.weight @ self.lora_down.weight + else: + diff = tucker_weight_from_conv( + self.lora_up.weight, self.lora_down.weight, self.lora_mid.weight + ) + scale = self.alpha / self.rank + weight = w + scale * diff.reshape(w.shape) + return weight.to(org_dtype) + + def passive_memory_usage(self): + return sum(param.numel() * param.element_size() for param in self.parameters()) class LoRAAdapter(WeightAdapterBase): @@ -13,6 +62,21 @@ class LoRAAdapter(WeightAdapterBase): self.loaded_keys = loaded_keys self.weights = weights + @classmethod + def create_train(cls, weight, rank=1, alpha=1.0): + out_dim = weight.shape[0] + in_dim = weight.shape[1:].numel() + mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) + mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) + torch.nn.init.kaiming_uniform_(mat1, a=5**0.5) + torch.nn.init.constant_(mat2, 0.0) + return LoraDiff( + (mat1, mat2, alpha, None, None, None) + ) + + def to_train(self): + return LoraDiff(self.weights) + @classmethod def load( cls, diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 010564704..d93fbd778 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -346,20 +346,6 @@ class FluxKontextProImageNode(ComfyNodeABC): }, } - 
@classmethod - def VALIDATE_INPUTS(cls, aspect_ratio: str): - try: - validate_aspect_ratio( - aspect_ratio, - minimum_ratio=cls.MINIMUM_RATIO, - maximum_ratio=cls.MAXIMUM_RATIO, - minimum_ratio_str=cls.MINIMUM_RATIO_STR, - maximum_ratio_str=cls.MAXIMUM_RATIO_STR, - ) - except Exception as e: - return str(e) - return True - RETURN_TYPES = (IO.IMAGE,) DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value FUNCTION = "api_call" @@ -380,6 +366,13 @@ class FluxKontextProImageNode(ComfyNodeABC): unique_id: Union[str, None] = None, **kwargs, ): + aspect_ratio = validate_aspect_ratio( + aspect_ratio, + minimum_ratio=self.MINIMUM_RATIO, + maximum_ratio=self.MAXIMUM_RATIO, + minimum_ratio_str=self.MINIMUM_RATIO_STR, + maximum_ratio_str=self.MAXIMUM_RATIO_STR, + ) if input_image is None: validate_string(prompt, strip_whitespace=False) operation = SynchronousOperation( @@ -395,13 +388,7 @@ class FluxKontextProImageNode(ComfyNodeABC): guidance=round(guidance, 1), steps=steps, seed=seed, - aspect_ratio=validate_aspect_ratio( - aspect_ratio, - minimum_ratio=self.MINIMUM_RATIO, - maximum_ratio=self.MAXIMUM_RATIO, - minimum_ratio_str=self.MINIMUM_RATIO_STR, - maximum_ratio_str=self.MAXIMUM_RATIO_STR, - ), + aspect_ratio=aspect_ratio, input_image=( input_image if input_image is None diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index b1cbf511d..b8487355f 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -324,7 +324,7 @@ class IdeogramV1(ComfyNodeABC): RETURN_TYPES = (IO.IMAGE,) FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram/v1" + CATEGORY = "api node/image/Ideogram" DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True @@ -483,7 +483,7 @@ class IdeogramV2(ComfyNodeABC): RETURN_TYPES = (IO.IMAGE,) FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram/v2" + CATEGORY = "api node/image/Ideogram" DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True @@ -649,7 +649,7 @@ class IdeogramV3(ComfyNodeABC): RETURN_TYPES = (IO.IMAGE,) FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram/v3" + CATEGORY = "api node/image/Ideogram" DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True diff --git a/comfy_config/types.py b/comfy_config/types.py index 611982083..5222cc59b 100644 --- a/comfy_config/types.py +++ b/comfy_config/types.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict from typing import List, Optional @@ -50,6 +50,7 @@ class ComfyConfig(BaseModel): icon: str = Field(default="", alias="Icon") models: List[Model] = Field(default_factory=list, alias="Models") includes: List[str] = Field(default_factory=list) + web: Optional[str] = None class License(BaseModel): @@ -66,6 +67,18 @@ class ProjectConfig(BaseModel): license: License = Field(default_factory=License) urls: URLs = Field(default_factory=URLs) + @field_validator('license', mode='before') + @classmethod + def validate_license(cls, v): + if isinstance(v, str): + return License(text=v) + elif isinstance(v, dict): + return License(**v) + elif isinstance(v, License): + return v + else: + return License() + class PyProjectConfig(BaseModel): project: ProjectConfig = Field(default_factory=ProjectConfig) @@ -77,4 +90,4 @@ class PyProjectSettings(BaseSettings): tool: dict = Field(default_factory=dict) - model_config = SettingsConfigDict() + model_config = SettingsConfigDict(extra='allow') diff --git a/comfy_extras/nodes_cosmos.py 
b/comfy_extras/nodes_cosmos.py index bd35ddb06..4f4960551 100644 --- a/comfy_extras/nodes_cosmos.py +++ b/comfy_extras/nodes_cosmos.py @@ -2,6 +2,7 @@ import nodes import torch import comfy.model_management import comfy.utils +import comfy.latent_formats class EmptyCosmosLatentVideo: @@ -75,8 +76,53 @@ class CosmosImageToVideoLatent: out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) return (out_latent,) +class CosmosPredict2ImageToVideoLatent: + @classmethod + def INPUT_TYPES(s): + return {"required": {"vae": ("VAE", ), + "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 93, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"start_image": ("IMAGE", ), + "end_image": ("IMAGE", ), + }} + + + RETURN_TYPES = ("LATENT",) + FUNCTION = "encode" + + CATEGORY = "conditioning/inpaint" + + def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None): + latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is None and end_image is None: + out_latent = {} + out_latent["samples"] = latent + return (out_latent,) + + mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) + + if start_image is not None: + latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1) + latent[:, :, :latent_temp.shape[-3]] = latent_temp + mask[:, :, :latent_temp.shape[-3]] *= 0.0 + + if end_image is not None: + latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0) + latent[:, :, -latent_temp.shape[-3]:] = latent_temp + mask[:, :, -latent_temp.shape[-3]:] *= 0.0 + + out_latent = {} + latent_format = comfy.latent_formats.Wan21() + latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask) + out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1)) + out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) + return (out_latent,) NODE_CLASS_MAPPINGS = { "EmptyCosmosLatentVideo": EmptyCosmosLatentVideo, "CosmosImageToVideoLatent": CosmosImageToVideoLatent, + "CosmosPredict2ImageToVideoLatent": CosmosPredict2ImageToVideoLatent, } diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 58b29f9a9..b1e0d4666 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -505,8 +505,8 @@ class GetImageSize: } } - RETURN_TYPES = (IO.INT, IO.INT) - RETURN_NAMES = ("width", "height") + RETURN_TYPES = (IO.INT, IO.INT, IO.INT) + RETURN_NAMES = ("width", "height", "batch_size") FUNCTION = "get_size" CATEGORY = "image" @@ -515,12 +515,13 @@ class GetImageSize: def get_size(self, image, unique_id=None) -> tuple[int, int]: height = image.shape[1] width = image.shape[2] + batch_size = image.shape[0] # Send progress text to display size on the node if unique_id: - PromptServer.instance.send_progress_text(f"width: {width}, height: {height}", unique_id) + PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id) - return width, height + return width, height, batch_size NODE_CLASS_MAPPINGS = { "ImageCrop": ImageCrop, diff --git 
a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 71a652ffa..ae5d2c563 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -189,7 +189,7 @@ class ModelSamplingContinuousEDM: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps"],), + "sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps", "cosmos_rflow"],), "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), }} @@ -202,6 +202,7 @@ class ModelSamplingContinuousEDM: def patch(self, model, sampling, sigma_max, sigma_min): m = model.clone() + sampling_base = comfy.model_sampling.ModelSamplingContinuousEDM latent_format = None sigma_data = 1.0 if sampling == "eps": @@ -215,8 +216,11 @@ class ModelSamplingContinuousEDM: sampling_type = comfy.model_sampling.EDM sigma_data = 0.5 latent_format = comfy.latent_formats.SDXL_Playground_2_5() + elif sampling == "cosmos_rflow": + sampling_type = comfy.model_sampling.COSMOS_RFLOW + sampling_base = comfy.model_sampling.ModelSamplingCosmosRFlow - class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type): + class ModelSamplingAdvanced(sampling_base, sampling_type): pass model_sampling = ModelSamplingAdvanced(model.model.model_config) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py new file mode 100644 index 000000000..fbff01010 --- /dev/null +++ b/comfy_extras/nodes_train.py @@ -0,0 +1,709 @@ +import datetime +import json +import logging +import os + +import numpy as np +import safetensors +import torch +from PIL import Image, ImageDraw, ImageFont +from PIL.PngImagePlugin import PngInfo +import torch.utils.checkpoint +import tqdm + +import comfy.samplers +import comfy.sd +import comfy.utils +import comfy.model_management +import comfy_extras.nodes_custom_sampler +import folder_paths +import node_helpers +from comfy.cli_args import args +from comfy.comfy_types.node_typing import IO +from comfy.weight_adapter import adapters + + +class TrainSampler(comfy.samplers.Sampler): + + def __init__(self, loss_fn, optimizer, loss_callback=None): + self.loss_fn = loss_fn + self.optimizer = optimizer + self.loss_callback = loss_callback + + def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + self.optimizer.zero_grad() + noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False) + latent = model_wrap.inner_model.model_sampling.noise_scaling( + torch.zeros_like(sigmas), + torch.zeros_like(noise, requires_grad=True), + latent_image, + False + ) + + # Ensure model is in training mode and computing gradients + # x0 pred + denoised = model_wrap(noise, sigmas, **extra_args) + try: + loss = self.loss_fn(denoised, latent.clone()) + except RuntimeError as e: + if "does not require grad and does not have a grad_fn" in str(e): + logging.info("WARNING: This is likely due to the model is loaded in inference mode.") + loss.backward() + if self.loss_callback: + self.loss_callback(loss.item()) + + self.optimizer.step() + # torch.cuda.memory._dump_snapshot("trainn.pickle") + # torch.cuda.memory._record_memory_history(enabled=None) + return torch.zeros_like(latent_image) + + +class BiasDiff(torch.nn.Module): + def __init__(self, bias): + 
super().__init__() + self.bias = bias + + def __call__(self, b): + org_dtype = b.dtype + return (b.to(self.bias) + self.bias).to(org_dtype) + + def passive_memory_usage(self): + return self.bias.nelement() * self.bias.element_size() + + def move_to(self, device): + self.to(device=device) + return self.passive_memory_usage() + + +def load_and_process_images(image_files, input_dir, resize_method="None"): + """Utility function to load and process a list of images. + + Args: + image_files: List of image filenames + input_dir: Base directory containing the images + resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad") + + Returns: + torch.Tensor: Batch of processed images + """ + if not image_files: + raise ValueError("No valid images found in input") + + output_images = [] + w, h = None, None + + for file in image_files: + image_path = os.path.join(input_dir, file) + img = node_helpers.pillow(Image.open, image_path) + + if img.mode == "I": + img = img.point(lambda i: i * (1 / 255)) + img = img.convert("RGB") + + if w is None and h is None: + w, h = img.size[0], img.size[1] + + # Resize image to first image + if img.size[0] != w or img.size[1] != h: + if resize_method == "Stretch": + img = img.resize((w, h), Image.Resampling.LANCZOS) + elif resize_method == "Crop": + img = img.crop((0, 0, w, h)) + elif resize_method == "Pad": + img = img.resize((w, h), Image.Resampling.LANCZOS) + elif resize_method == "None": + raise ValueError( + "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images." + ) + + img_array = np.array(img).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array)[None,] + output_images.append(img_tensor) + + return torch.cat(output_images, dim=0) + + +class LoadImageSetNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "images": ( + [ + f + for f in os.listdir(folder_paths.get_input_directory()) + if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff")) + ], + {"image_upload": True, "allow_batch": True}, + ) + }, + "optional": { + "resize_method": ( + ["None", "Stretch", "Crop", "Pad"], + {"default": "None"}, + ), + }, + } + + INPUT_IS_LIST = True + RETURN_TYPES = ("IMAGE",) + FUNCTION = "load_images" + CATEGORY = "loaders" + EXPERIMENTAL = True + DESCRIPTION = "Loads a batch of images from a directory for training." 
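A minimal usage sketch of the `load_and_process_images` helper above, assuming it runs inside comfy_extras/nodes_train.py; the file names are hypothetical:

    import folder_paths

    # Two same-sized images from the ComfyUI input directory, stacked into one batch.
    batch = load_and_process_images(
        ["example_001.png", "example_002.png"],   # assumed file names
        folder_paths.get_input_directory(),
        resize_method="Stretch",                   # later images are resized to match the first one
    )
    print(batch.shape, batch.dtype)                # torch.Size([2, H, W, 3]) torch.float32, values in [0, 1]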
+ + @classmethod + def VALIDATE_INPUTS(s, images, resize_method): + filenames = images[0] if isinstance(images[0], list) else images + + for image in filenames: + if not folder_paths.exists_annotated_filepath(image): + return "Invalid image file: {}".format(image) + return True + + def load_images(self, input_files, resize_method): + input_dir = folder_paths.get_input_directory() + valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"] + image_files = [ + f + for f in input_files + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + output_tensor = load_and_process_images(image_files, input_dir, resize_method) + return (output_tensor,) + + +class LoadImageSetFromFolderNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}) + }, + "optional": { + "resize_method": ( + ["None", "Stretch", "Crop", "Pad"], + {"default": "None"}, + ), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "load_images" + CATEGORY = "loaders" + EXPERIMENTAL = True + DESCRIPTION = "Loads a batch of images from a directory for training." + + def load_images(self, folder, resize_method): + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] + image_files = [ + f + for f in os.listdir(sub_input_dir) + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method) + return (output_tensor,) + + +def draw_loss_graph(loss_map, steps): + width, height = 500, 300 + img = Image.new("RGB", (width, height), "white") + draw = ImageDraw.Draw(img) + + min_loss, max_loss = min(loss_map.values()), max(loss_map.values()) + scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_map.values()] + + prev_point = (0, height - int(scaled_loss[0] * height)) + for i, l in enumerate(scaled_loss[1:], start=1): + x = int(i / (steps - 1) * width) + y = height - int(l * height) + draw.line([prev_point, (x, y)], fill="blue", width=2) + prev_point = (x, y) + + return img + + +def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None): + if result is None: + result = [] + elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)): + result.append(model) + logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})") + return result + name = name or "root" + for next_name, child in model.named_children(): + find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}") + return result + + +def patch(m): + if not hasattr(m, "forward"): + return + org_forward = m.forward + def fwd(args, kwargs): + return org_forward(*args, **kwargs) + def checkpointing_fwd(*args, **kwargs): + return torch.utils.checkpoint.checkpoint( + fwd, args, kwargs, use_reentrant=False + ) + m.org_forward = org_forward + m.forward = checkpointing_fwd + + +def unpatch(m): + if hasattr(m, "org_forward"): + m.forward = m.org_forward + del m.org_forward + + +class TrainLoraNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}), + "latents": ( + "LATENT", + { + "tooltip": "The Latents to use for training, serve as dataset/input of the model." 
+ }, + ), + "positive": ( + IO.CONDITIONING, + {"tooltip": "The positive conditioning to use for training."}, + ), + "batch_size": ( + IO.INT, + { + "default": 1, + "min": 1, + "max": 10000, + "step": 1, + "tooltip": "The batch size to use for training.", + }, + ), + "steps": ( + IO.INT, + { + "default": 16, + "min": 1, + "max": 100000, + "tooltip": "The number of steps to train the LoRA for.", + }, + ), + "learning_rate": ( + IO.FLOAT, + { + "default": 0.0005, + "min": 0.0000001, + "max": 1.0, + "step": 0.000001, + "tooltip": "The learning rate to use for training.", + }, + ), + "rank": ( + IO.INT, + { + "default": 8, + "min": 1, + "max": 128, + "tooltip": "The rank of the LoRA layers.", + }, + ), + "optimizer": ( + ["AdamW", "Adam", "SGD", "RMSprop"], + { + "default": "AdamW", + "tooltip": "The optimizer to use for training.", + }, + ), + "loss_function": ( + ["MSE", "L1", "Huber", "SmoothL1"], + { + "default": "MSE", + "tooltip": "The loss function to use for training.", + }, + ), + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 0xFFFFFFFFFFFFFFFF, + "tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)", + }, + ), + "training_dtype": ( + ["bf16", "fp32"], + {"default": "bf16", "tooltip": "The dtype to use for training."}, + ), + "lora_dtype": ( + ["bf16", "fp32"], + {"default": "bf16", "tooltip": "The dtype to use for lora."}, + ), + "existing_lora": ( + folder_paths.get_filename_list("loras") + ["[None]"], + { + "default": "[None]", + "tooltip": "The existing LoRA to append to. Set to None for new LoRA.", + }, + ), + }, + } + + RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT) + RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps") + FUNCTION = "train" + CATEGORY = "training" + EXPERIMENTAL = True + + def train( + self, + model, + latents, + positive, + batch_size, + steps, + learning_rate, + rank, + optimizer, + loss_function, + seed, + training_dtype, + lora_dtype, + existing_lora, + ): + mp = model.clone() + dtype = node_helpers.string_to_torch_dtype(training_dtype) + lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) + mp.set_model_compute_dtype(dtype) + + latents = latents["samples"].to(dtype) + num_images = latents.shape[0] + + with torch.inference_mode(False): + lora_sd = {} + generator = torch.Generator() + generator.manual_seed(seed) + + # Load existing LoRA weights if provided + existing_weights = {} + existing_steps = 0 + if existing_lora != "[None]": + lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora) + # Extract steps from filename like "trained_lora_10_steps_20250225_203716" + existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1]) + if lora_path: + existing_weights = comfy.utils.load_torch_file(lora_path) + + all_weight_adapters = [] + for n, m in mp.model.named_modules(): + if hasattr(m, "weight_function"): + if m.weight is not None: + key = "{}.weight".format(n) + shape = m.weight.shape + if len(shape) >= 2: + alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) + dora_scale = existing_weights.get( + f"{key}.dora_scale", None + ) + for adapter_cls in adapters: + existing_adapter = adapter_cls.load( + n, existing_weights, alpha, dora_scale + ) + if existing_adapter is not None: + break + else: + # If no existing adapter found, use LoRA + # We will add algo option in the future + existing_adapter = None + adapter_cls = adapters[0] + + if existing_adapter is not None: + train_adapter = existing_adapter.to_train().to(lora_dtype) + 
else:
+                                # Use LoRA with alpha=1.0 by default
+                                train_adapter = adapter_cls.create_train(
+                                    m.weight, rank=rank, alpha=1.0
+                                ).to(lora_dtype)
+                            for name, parameter in train_adapter.named_parameters():
+                                lora_sd[f"{n}.{name}"] = parameter
+
+                            mp.add_weight_wrapper(key, train_adapter)
+                            all_weight_adapters.append(train_adapter)
+                        else:
+                            diff = torch.nn.Parameter(
+                                torch.zeros(
+                                    m.weight.shape, dtype=lora_dtype, requires_grad=True
+                                )
+                            )
+                            diff_module = BiasDiff(diff)
+                            mp.add_weight_wrapper(key, diff_module)
+                            all_weight_adapters.append(diff_module)
+                            lora_sd["{}.diff".format(n)] = diff
+                    if hasattr(m, "bias") and m.bias is not None:
+                        key = "{}.bias".format(n)
+                        bias = torch.nn.Parameter(
+                            torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True)
+                        )
+                        bias_module = BiasDiff(bias)
+                        lora_sd["{}.diff_b".format(n)] = bias
+                        mp.add_weight_wrapper(key, bias_module)
+                        all_weight_adapters.append(bias_module)
+
+            if optimizer == "Adam":
+                optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
+            elif optimizer == "AdamW":
+                optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
+            elif optimizer == "SGD":
+                optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
+            elif optimizer == "RMSprop":
+                optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
+
+            # Set up the loss function based on the selection
+            if loss_function == "MSE":
+                criterion = torch.nn.MSELoss()
+            elif loss_function == "L1":
+                criterion = torch.nn.L1Loss()
+            elif loss_function == "Huber":
+                criterion = torch.nn.HuberLoss()
+            elif loss_function == "SmoothL1":
+                criterion = torch.nn.SmoothL1Loss()
+
+            # Set up the models for training
+            for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
+                patch(m)
+            comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
+
+            # Set up the sampler and guider
+            loss_map = {"loss": []}
+            def loss_callback(loss):
+                loss_map["loss"].append(loss)
+                pbar.set_postfix({"loss": f"{loss:.4f}"})
+            train_sampler = TrainSampler(
+                criterion, optimizer, loss_callback=loss_callback
+            )
+            guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
+            guider.set_conds(positive)  # Set conditioning from input
+            ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced()
+
+            # yoland: this currently resizes to the first image in the dataset
+
+            # Training loop
+            torch.cuda.empty_cache()
+            try:
+                for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
+                    # Generate random sigma
+                    sigma = mp.model.model_sampling.percent_to_sigma(
+                        torch.rand((1,)).item()
+                    )
+                    sigma = torch.tensor([sigma])
+
+                    noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)
+
+                    indices = torch.randperm(num_images)[:batch_size]
+                    ss.sample(
+                        noise, guider, train_sampler, sigma, {"samples": latents[indices].clone()}
+                    )
+            finally:
+                for m in mp.model.modules():
+                    unpatch(m)
+                del ss, train_sampler, optimizer
+                torch.cuda.empty_cache()
+
+            for adapter in all_weight_adapters:
+                adapter.requires_grad_(False)
+
+            for param in lora_sd:
+                lora_sd[param] = lora_sd[param].to(lora_dtype)
+
+            return (mp, lora_sd, loss_map, steps + existing_steps)
+
+
+class LoraModelLoader:
+    def __init__(self):
+        self.loaded_lora = None
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
+                "lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to apply to the diffusion model."}),
+                "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
+            }
+        }
+
+    RETURN_TYPES = ("MODEL",)
+    OUTPUT_TOOLTIPS = ("The modified diffusion model.",)
+    FUNCTION = "load_lora_model"
+
+    CATEGORY = "loaders"
+    DESCRIPTION = "Load trained LoRA weights from the Train LoRA node."
+    EXPERIMENTAL = True
+
+    def load_lora_model(self, model, lora, strength_model):
+        if strength_model == 0:
+            return (model, )
+
+        model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
+        return (model_lora, )
+
+
+class SaveLoRA:
+    def __init__(self):
+        self.output_dir = folder_paths.get_output_directory()
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "lora": (
+                    IO.LORA_MODEL,
+                    {
+                        "tooltip": "The LoRA model to save. Do not use the model with LoRA layers."
+                    },
+                ),
+                "prefix": (
+                    "STRING",
+                    {
+                        "default": "loras/ComfyUI_trained_lora",
+                        "tooltip": "The prefix to use for the saved LoRA file.",
+                    },
+                ),
+            },
+            "optional": {
+                "steps": (
+                    IO.INT,
+                    {
+                        "forceInput": True,
+                        "tooltip": "Optional: The number of steps the LoRA has been trained for; used to name the saved file.",
+                    },
+                ),
+            },
+        }
+
+    RETURN_TYPES = ()
+    FUNCTION = "save"
+    CATEGORY = "loaders"
+    EXPERIMENTAL = True
+    OUTPUT_NODE = True
+
+    def save(self, lora, prefix, steps=None):
+        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(prefix, self.output_dir)
+        if steps is None:
+            output_checkpoint = f"{filename}_{counter:05}_.safetensors"
+        else:
+            output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
+        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
+        safetensors.torch.save_file(lora, output_checkpoint)
+        return {}
+
+
+class LossGraphNode:
+    def __init__(self):
+        self.output_dir = folder_paths.get_temp_directory()
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "loss": (IO.LOSS_MAP, {"default": {}}),
+                "filename_prefix": (IO.STRING, {"default": "loss_graph"}),
+            },
+            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
+        }
+
+    RETURN_TYPES = ()
+    FUNCTION = "plot_loss"
+    OUTPUT_NODE = True
+    CATEGORY = "training"
+    EXPERIMENTAL = True
+    DESCRIPTION = "Plots the loss graph and saves it to the temp directory."
+ + def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None): + loss_values = loss["loss"] + width, height = 800, 480 + margin = 40 + + img = Image.new( + "RGB", (width + margin, height + margin), "white" + ) # Extend canvas + draw = ImageDraw.Draw(img) + + min_loss, max_loss = min(loss_values), max(loss_values) + scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_values] + + steps = len(loss_values) + + prev_point = (margin, height - int(scaled_loss[0] * height)) + for i, l in enumerate(scaled_loss[1:], start=1): + x = margin + int(i / steps * width) # Scale X properly + y = height - int(l * height) + draw.line([prev_point, (x, y)], fill="blue", width=2) + prev_point = (x, y) + + draw.line([(margin, 0), (margin, height)], fill="black", width=2) # Y-axis + draw.line( + [(margin, height), (width + margin, height)], fill="black", width=2 + ) # X-axis + + font = None + try: + font = ImageFont.truetype("arial.ttf", 12) + except IOError: + font = ImageFont.load_default() + + # Add axis labels + draw.text((5, height // 2), "Loss", font=font, fill="black") + draw.text((width // 2, height + 10), "Steps", font=font, fill="black") + + # Add min/max loss values + draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black") + draw.text( + (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black" + ) + + metadata = None + if not args.disable_metadata: + metadata = PngInfo() + if prompt is not None: + metadata.add_text("prompt", json.dumps(prompt)) + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata.add_text(x, json.dumps(extra_pnginfo[x])) + + date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + img.save( + os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"), + pnginfo=metadata, + ) + return { + "ui": { + "images": [ + { + "filename": f"{filename_prefix}_{date}.png", + "subfolder": "", + "type": "temp", + } + ] + } + } + + +NODE_CLASS_MAPPINGS = { + "TrainLoraNode": TrainLoraNode, + "SaveLoRANode": SaveLoRA, + "LoraModelLoader": LoraModelLoader, + "LoadImageSetFromFolderNode": LoadImageSetFromFolderNode, + "LossGraphNode": LossGraphNode, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "TrainLoraNode": "Train LoRA", + "SaveLoRANode": "Save LoRA Weights", + "LoraModelLoader": "Load LoRA Model", + "LoadImageSetFromFolderNode": "Load Image Dataset from Folder", + "LossGraphNode": "Plot Loss Graph", +} diff --git a/comfy_extras/nodes_webcam.py b/comfy_extras/nodes_webcam.py index 062b15cf8..5bf80b4c6 100644 --- a/comfy_extras/nodes_webcam.py +++ b/comfy_extras/nodes_webcam.py @@ -23,6 +23,10 @@ class WebcamCapture(nodes.LoadImage): def load_capture(self, image, **kwargs): return super().load_image(folder_paths.get_annotated_filepath(image)) + @classmethod + def IS_CHANGED(cls, image, width, height, capture_on_queue): + return super().IS_CHANGED(image) + NODE_CLASS_MAPPINGS = { "WebcamCapture": WebcamCapture, diff --git a/comfyui_version.py b/comfyui_version.py index f742410b1..fedd3466f 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. 
-__version__ = "0.3.39" +__version__ = "0.3.41" diff --git a/execution.py b/execution.py index f994c75c2..6b885e476 100644 --- a/execution.py +++ b/execution.py @@ -1,24 +1,36 @@ -import sys import copy -import logging -import threading import heapq +import inspect +import logging +import sys +import threading import time import traceback from enum import Enum -import inspect from typing import List, Literal, NamedTuple, Optional import torch -import nodes import comfy.model_management -from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker -from comfy_execution.graph_utils import is_link, GraphBuilder -from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID +import nodes +from comfy_execution.caching import ( + CacheKeySetID, + CacheKeySetInputSignature, + DependencyAwareCache, + HierarchicalCache, + LRUCache, +) +from comfy_execution.graph import ( + DynamicPrompt, + ExecutionBlocker, + ExecutionList, + get_input_info, +) +from comfy_execution.graph_utils import GraphBuilder, is_link from comfy_execution.validation import validate_node_input from comfy_api.v3.io import NodeOutput, ComfyNodeV3, Hidden + class ExecutionResult(Enum): SUCCESS = 0 FAILURE = 1 diff --git a/folder_paths.py b/folder_paths.py index f0b3fd103..9ec952940 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -276,6 +276,9 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str]) def get_full_path(folder_name: str, filename: str) -> str | None: + """ + Get the full path of a file in a folder, has to be a file + """ global folder_names_and_paths folder_name = map_legacy(folder_name) if folder_name not in folder_names_and_paths: @@ -293,6 +296,9 @@ def get_full_path(folder_name: str, filename: str) -> str | None: def get_full_path_or_raise(folder_name: str, filename: str) -> str: + """ + Get the full path of a file in a folder, has to be a file + """ full_path = get_full_path(folder_name, filename) if full_path is None: raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.") @@ -394,3 +400,26 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im os.makedirs(full_output_folder, exist_ok=True) counter = 1 return full_output_folder, filename, counter, subfolder, filename_prefix + +def get_input_subfolders() -> list[str]: + """Returns a list of all subfolder paths in the input directory, recursively. 
+ + Returns: + List of folder paths relative to the input directory, excluding the root directory + """ + input_dir = get_input_directory() + folders = [] + + try: + if not os.path.exists(input_dir): + return [] + + for root, dirs, _ in os.walk(input_dir): + rel_path = os.path.relpath(root, input_dir) + if rel_path != ".": # Only include non-root directories + # Normalize path separators to forward slashes + folders.append(rel_path.replace(os.sep, '/')) + + return sorted(folders) + except FileNotFoundError: + return [] diff --git a/main.py b/main.py index fb1f8d20b..c8c4194d4 100644 --- a/main.py +++ b/main.py @@ -17,7 +17,6 @@ if __name__ == "__main__": os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' os.environ['DO_NOT_TRACK'] = '1' - setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) def apply_custom_paths(): @@ -238,6 +237,15 @@ def cleanup_temp(): shutil.rmtree(temp_dir, ignore_errors=True) +def setup_database(): + try: + from app.database.db import init_db, dependencies_available + if dependencies_available(): + init_db() + except Exception as e: + logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}") + + def start_comfyui(asyncio_loop=None): """ Starts the ComfyUI server using the provided asyncio event loop or creates a new one. @@ -266,6 +274,7 @@ def start_comfyui(asyncio_loop=None): hook_breaker_ac10a0.restore_functions() cuda_malloc_warning() + setup_database() prompt_server.add_routes() hijack_progress(prompt_server) @@ -300,6 +309,9 @@ if __name__ == "__main__": logging.info("Python version: {}".format(sys.version)) logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) + if sys.version_info.major == 3 and sys.version_info.minor < 10: + logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 
3.12 and above is recommended.")
+
     event_loop, _, start_all_func = start_comfyui()
     try:
         x = start_all_func()
diff --git a/nodes.py b/nodes.py
index 022732bbf..2dfa8e8f1 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2126,6 +2126,25 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes

         LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir)

+        try:
+            from comfy_config import config_parser
+
+            project_config = config_parser.extract_node_configuration(module_path)
+
+            web_dir_name = project_config.tool_comfy.web
+
+            if web_dir_name:
+                web_dir_path = os.path.join(module_path, web_dir_name)
+
+                if os.path.isdir(web_dir_path):
+                    project_name = project_config.project.name
+
+                    EXTENSION_WEB_DIRS[project_name] = web_dir_path
+
+                    logging.info("Automatically registering web folder {} for {}".format(web_dir_name, project_name))
+        except Exception as e:
+            logging.warning(f"Unable to parse pyproject.toml due to missing dependency pydantic-settings, please run 'pip install -r requirements.txt': {e}")
+
         if hasattr(module, "WEB_DIRECTORY") and getattr(module, "WEB_DIRECTORY") is not None:
             web_dir = os.path.abspath(os.path.join(module_dir, getattr(module, "WEB_DIRECTORY")))
             if os.path.isdir(web_dir):
@@ -2225,6 +2244,7 @@ def init_builtin_extra_nodes():
         "nodes_model_downscale.py",
         "nodes_images.py",
         "nodes_video_model.py",
+        "nodes_train.py",
         "nodes_sag.py",
         "nodes_perpneg.py",
         "nodes_stable3d.py",
diff --git a/pyproject.toml b/pyproject.toml
index 28a6158e0..c572ad4c6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.39"
+version = "0.3.41"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
diff --git a/requirements.txt b/requirements.txt
index 1c1ff54ac..336ec9d57 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
-comfyui-frontend-package==1.21.6
-comfyui-workflow-templates==0.1.25
-comfyui-embedded-docs==0.2.0
+comfyui-frontend-package==1.21.7
+comfyui-workflow-templates==0.1.28
+comfyui-embedded-docs==0.2.2
 torch
 torchsde
 torchvision
@@ -18,6 +18,8 @@ Pillow
 scipy
 tqdm
 psutil
+alembic
+SQLAlchemy

 #non essential dependencies:
 kornia>=0.7.1
@@ -25,3 +27,4 @@ spandrel
 soundfile
 av>=14.2.0
 pydantic~=2.0
+pydantic-settings~=2.0
diff --git a/server.py b/server.py
index c5eb14484..1a135fca7 100644
--- a/server.py
+++ b/server.py
@@ -791,7 +791,7 @@ class PromptServer():
             if hasattr(Image, 'Resampling'):
                 resampling = Image.Resampling.BILINEAR
             else:
-                resampling = Image.ANTIALIAS
+                resampling = Image.Resampling.LANCZOS

             image = ImageOps.contain(image, (max_size, max_size), resampling)
             type_num = 1
diff --git a/tests-unit/folder_paths_test/misc_test.py b/tests-unit/folder_paths_test/misc_test.py
new file mode 100644
index 000000000..fcf667453
--- /dev/null
+++ b/tests-unit/folder_paths_test/misc_test.py
@@ -0,0 +1,51 @@
+import pytest
+import os
+import tempfile
+from folder_paths import get_input_subfolders, set_input_directory
+
+@pytest.fixture(scope="module")
+def mock_folder_structure():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # Create a nested folder structure
+        folders = [
+            "folder1",
+            "folder1/subfolder1",
+            "folder1/subfolder2",
+            "folder2",
+            "folder2/deep",
+            "folder2/deep/nested",
+            "empty_folder"
+        ]
+
+        # Create the folders
+        for folder in folders:
+            os.makedirs(os.path.join(temp_dir, folder))
+
+        # Add some files to test they're not included
+        with open(os.path.join(temp_dir, "root_file.txt"), "w") as f:
+            f.write("test")
+        with open(os.path.join(temp_dir, "folder1", "test.txt"), "w") as f:
+            f.write("test")
+
+        set_input_directory(temp_dir)
+        yield temp_dir
+
+
+def test_gets_all_folders(mock_folder_structure):
+    folders = get_input_subfolders()
+    expected = ["folder1", "folder1/subfolder1", "folder1/subfolder2",
+                "folder2", "folder2/deep", "folder2/deep/nested", "empty_folder"]
+    assert sorted(folders) == sorted(expected)
+
+
+def test_handles_nonexistent_input_directory():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        nonexistent = os.path.join(temp_dir, "nonexistent")
+        set_input_directory(nonexistent)
+        assert get_input_subfolders() == []
+
+
+def test_empty_input_directory():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        set_input_directory(temp_dir)
+        assert get_input_subfolders() == []  # Empty since we don't include root
diff --git a/utils/install_util.py b/utils/install_util.py
new file mode 100644
index 000000000..0f59bcf91
--- /dev/null
+++ b/utils/install_util.py
@@ -0,0 +1,18 @@
+from pathlib import Path
+import sys
+
+# The path to the requirements.txt file
+requirements_path = Path(__file__).parents[1] / "requirements.txt"
+
+
+def get_missing_requirements_message():
+    """Return the warning message to display when a package is missing."""
+
+    extra = ""
+    if sys.flags.no_user_site:
+        extra = "-s "
+    return f"""
+Please install the updated requirements.txt file by running:
+{sys.executable} {extra}-m pip install -r {requirements_path}
+If you are on the portable package, you can run: update\\update_comfyui.bat to solve this problem.
+""".strip()