mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-27 16:26:39 +00:00
Merge branch 'master' into worksplit-multigpu
This commit is contained in:
commit
1ae98932f1
8
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
8
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@ -15,6 +15,14 @@ body:
|
||||
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
|
||||
|
||||
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
|
||||
- type: checkboxes
|
||||
id: custom-nodes-test
|
||||
attributes:
|
||||
label: Custom Node Testing
|
||||
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
|
||||
options:
|
||||
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Expected Behavior
|
||||
|
8
.github/ISSUE_TEMPLATE/user-support.yml
vendored
8
.github/ISSUE_TEMPLATE/user-support.yml
vendored
@ -11,6 +11,14 @@ body:
|
||||
**2:** You have made an effort to find public answers to your question before asking here. In other words, you googled it first, and scrolled through recent help topics.
|
||||
|
||||
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
|
||||
- type: checkboxes
|
||||
id: custom-nodes-test
|
||||
attributes:
|
||||
label: Custom Node Testing
|
||||
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
|
||||
options:
|
||||
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Your question
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
[![Website][website-shield]][website-url]
|
||||
[![Dynamic JSON Badge][discord-shield]][discord-url]
|
||||
[![Twitter][twitter-shield]][twitter-url]
|
||||
[![Matrix][matrix-shield]][matrix-url]
|
||||
<br>
|
||||
[![][github-release-shield]][github-release-link]
|
||||
@ -20,6 +21,8 @@
|
||||
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
|
||||
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
|
||||
[discord-url]: https://www.comfy.org/discord
|
||||
[twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI
|
||||
[twitter-url]: https://x.com/ComfyUI
|
||||
|
||||
[github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver
|
||||
[github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases
|
||||
@ -62,12 +65,13 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
|
||||
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
|
||||
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
|
||||
- [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
|
||||
- Video Models
|
||||
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
|
||||
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
|
||||
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
|
||||
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
|
||||
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
|
||||
- Audio Models
|
||||
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
@ -95,7 +99,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
|
||||
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
|
||||
- Starts up very fast.
|
||||
- Works fully offline: will never download anything.
|
||||
- Works fully offline: core will never download anything unless you want to.
|
||||
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
|
||||
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
|
||||
|
||||
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||
|
84
alembic.ini
Normal file
84
alembic.ini
Normal file
@ -0,0 +1,84 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
# Use forward slashes (/) also on windows to provide an os agnostic path
|
||||
script_location = alembic_db
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
|
||||
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to ZoneInfo()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to alembic_db/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "version_path_separator" below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions
|
||||
|
||||
# version path separator; As mentioned above, this is the character used to split
|
||||
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
||||
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
||||
# Valid values for version_path_separator are:
|
||||
#
|
||||
# version_path_separator = :
|
||||
# version_path_separator = ;
|
||||
# version_path_separator = space
|
||||
# version_path_separator = newline
|
||||
#
|
||||
# Use os.pathsep. Default configuration used for new projects.
|
||||
version_path_separator = os
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
sqlalchemy.url = sqlite:///user/comfyui.db
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = %(here)s/.venv/bin/ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
4
alembic_db/README.md
Normal file
4
alembic_db/README.md
Normal file
@ -0,0 +1,4 @@
|
||||
## Generate new revision
|
||||
|
||||
1. Update models in `/app/database/models.py`
|
||||
2. Run `alembic revision --autogenerate -m "{your message}"`
|
64
alembic_db/env.py
Normal file
64
alembic_db/env.py
Normal file
@ -0,0 +1,64 @@
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
|
||||
from app.database.models import Base
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode.
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
"""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
28
alembic_db/script.py.mako
Normal file
28
alembic_db/script.py.mako
Normal file
@ -0,0 +1,28 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
${downgrades if downgrades else "pass"}
|
112
app/database/db.py
Normal file
112
app/database/db.py
Normal file
@ -0,0 +1,112 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from app.logger import log_startup_warning
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
from comfy.cli_args import args
|
||||
|
||||
_DB_AVAILABLE = False
|
||||
Session = None
|
||||
|
||||
|
||||
try:
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from alembic.script import ScriptDirectory
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
_DB_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
log_startup_warning(
|
||||
f"""
|
||||
------------------------------------------------------------------------
|
||||
Error importing dependencies: {e}
|
||||
{get_missing_requirements_message()}
|
||||
This error is happening because ComfyUI now uses a local sqlite database.
|
||||
------------------------------------------------------------------------
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def dependencies_available():
|
||||
"""
|
||||
Temporary function to check if the dependencies are available
|
||||
"""
|
||||
return _DB_AVAILABLE
|
||||
|
||||
|
||||
def can_create_session():
|
||||
"""
|
||||
Temporary function to check if the database is available to create a session
|
||||
During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
|
||||
"""
|
||||
return dependencies_available() and Session is not None
|
||||
|
||||
|
||||
def get_alembic_config():
|
||||
root_path = os.path.join(os.path.dirname(__file__), "../..")
|
||||
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
|
||||
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
|
||||
|
||||
config = Config(config_path)
|
||||
config.set_main_option("script_location", scripts_path)
|
||||
config.set_main_option("sqlalchemy.url", args.database_url)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_db_path():
|
||||
url = args.database_url
|
||||
if url.startswith("sqlite:///"):
|
||||
return url.split("///")[1]
|
||||
else:
|
||||
raise ValueError(f"Unsupported database URL '{url}'.")
|
||||
|
||||
|
||||
def init_db():
|
||||
db_url = args.database_url
|
||||
logging.debug(f"Database URL: {db_url}")
|
||||
db_path = get_db_path()
|
||||
db_exists = os.path.exists(db_path)
|
||||
|
||||
config = get_alembic_config()
|
||||
|
||||
# Check if we need to upgrade
|
||||
engine = create_engine(db_url)
|
||||
conn = engine.connect()
|
||||
|
||||
context = MigrationContext.configure(conn)
|
||||
current_rev = context.get_current_revision()
|
||||
|
||||
script = ScriptDirectory.from_config(config)
|
||||
target_rev = script.get_current_head()
|
||||
|
||||
if target_rev is None:
|
||||
logging.warning("No target revision found.")
|
||||
elif current_rev != target_rev:
|
||||
# Backup the database pre upgrade
|
||||
backup_path = db_path + ".bkp"
|
||||
if db_exists:
|
||||
shutil.copy(db_path, backup_path)
|
||||
else:
|
||||
backup_path = None
|
||||
|
||||
try:
|
||||
command.upgrade(config, target_rev)
|
||||
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
|
||||
except Exception as e:
|
||||
if backup_path:
|
||||
# Restore the database from backup if upgrade fails
|
||||
shutil.copy(backup_path, db_path)
|
||||
os.remove(backup_path)
|
||||
logging.exception("Error upgrading database: ")
|
||||
raise e
|
||||
|
||||
global Session
|
||||
Session = sessionmaker(bind=engine)
|
||||
|
||||
|
||||
def create_session():
|
||||
return Session()
|
14
app/database/models.py
Normal file
14
app/database/models.py
Normal file
@ -0,0 +1,14 @@
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def to_dict(obj):
|
||||
fields = obj.__table__.columns.keys()
|
||||
return {
|
||||
field: (val.to_dict() if hasattr(val, "to_dict") else val)
|
||||
for field in fields
|
||||
if (val := getattr(obj, field))
|
||||
}
|
||||
|
||||
# TODO: Define models here
|
@ -16,26 +16,17 @@ from importlib.metadata import version
|
||||
import requests
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from utils.install_util import get_missing_requirements_message, requirements_path
|
||||
|
||||
from comfy.cli_args import DEFAULT_VERSION_STRING
|
||||
import app.logger
|
||||
|
||||
# The path to the requirements.txt file
|
||||
req_path = Path(__file__).parents[1] / "requirements.txt"
|
||||
|
||||
|
||||
def frontend_install_warning_message():
|
||||
"""The warning message to display when the frontend version is not up to date."""
|
||||
|
||||
extra = ""
|
||||
if sys.flags.no_user_site:
|
||||
extra = "-s "
|
||||
return f"""
|
||||
Please install the updated requirements.txt file by running:
|
||||
{sys.executable} {extra}-m pip install -r {req_path}
|
||||
{get_missing_requirements_message()}
|
||||
|
||||
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
|
||||
|
||||
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem
|
||||
""".strip()
|
||||
|
||||
|
||||
@ -48,7 +39,7 @@ def check_frontend_version():
|
||||
try:
|
||||
frontend_version_str = version("comfyui-frontend-package")
|
||||
frontend_version = parse_version(frontend_version_str)
|
||||
with open(req_path, "r", encoding="utf-8") as f:
|
||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
||||
required_frontend = parse_version(f.readline().split("=")[-1])
|
||||
if frontend_version < required_frontend:
|
||||
app.logger.log_startup_warning(
|
||||
@ -121,9 +112,22 @@ class FrontEndProvider:
|
||||
response.raise_for_status() # Raises an HTTPError if the response was an error
|
||||
return response.json()
|
||||
|
||||
@cached_property
|
||||
def latest_prerelease(self) -> Release:
|
||||
"""Get the latest pre-release version - even if it's older than the latest release"""
|
||||
release = [release for release in self.all_releases if release["prerelease"]]
|
||||
|
||||
if not release:
|
||||
raise ValueError("No pre-releases found")
|
||||
|
||||
# GitHub returns releases in reverse chronological order, so first is latest
|
||||
return release[0]
|
||||
|
||||
def get_release(self, version: str) -> Release:
|
||||
if version == "latest":
|
||||
return self.latest_release
|
||||
elif version == "prerelease":
|
||||
return self.latest_prerelease
|
||||
else:
|
||||
for release in self.all_releases:
|
||||
if release["tag_name"] in [version, f"v{version}"]:
|
||||
@ -230,7 +234,7 @@ comfyui-workflow-templates is not installed.
|
||||
Raises:
|
||||
argparse.ArgumentTypeError: If the version string is invalid.
|
||||
"""
|
||||
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$"
|
||||
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+[-._a-zA-Z0-9]*|latest|prerelease)$"
|
||||
match_result = re.match(VERSION_PATTERN, value)
|
||||
if match_result is None:
|
||||
raise argparse.ArgumentTypeError(f"Invalid version string: {value}")
|
||||
|
@ -203,6 +203,11 @@ parser.add_argument(
|
||||
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
|
||||
)
|
||||
|
||||
database_default_path = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
|
||||
)
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
else:
|
||||
|
@ -37,6 +37,8 @@ class IO(StrEnum):
|
||||
CONTROL_NET = "CONTROL_NET"
|
||||
VAE = "VAE"
|
||||
MODEL = "MODEL"
|
||||
LORA_MODEL = "LORA_MODEL"
|
||||
LOSS_MAP = "LOSS_MAP"
|
||||
CLIP_VISION = "CLIP_VISION"
|
||||
CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
|
||||
STYLE_MODEL = "STYLE_MODEL"
|
||||
|
@ -433,8 +433,9 @@ class ControlLora(ControlNet):
|
||||
pass
|
||||
|
||||
for k in self.control_weights:
|
||||
if k not in {"lora_controlnet"}:
|
||||
comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
|
||||
if (k not in {"lora_controlnet"}):
|
||||
if (k.endswith(".up") or k.endswith(".down") or k.endswith(".weight") or k.endswith(".bias")) and ("__" not in k):
|
||||
comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
|
||||
|
||||
def copy(self):
|
||||
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
|
||||
|
@ -1,4 +1,5 @@
|
||||
import math
|
||||
from functools import partial
|
||||
|
||||
from scipy import integrate
|
||||
import torch
|
||||
@ -142,6 +143,33 @@ class BrownianTreeNoiseSampler:
|
||||
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
|
||||
|
||||
|
||||
def sigma_to_half_log_snr(sigma, model_sampling):
|
||||
"""Convert sigma to half-logSNR log(alpha_t / sigma_t)."""
|
||||
if isinstance(model_sampling, comfy.model_sampling.CONST):
|
||||
# log((1 - t) / t) = log((1 - sigma) / sigma)
|
||||
return sigma.logit().neg()
|
||||
return sigma.log().neg()
|
||||
|
||||
|
||||
def half_log_snr_to_sigma(half_log_snr, model_sampling):
|
||||
"""Convert half-logSNR log(alpha_t / sigma_t) to sigma."""
|
||||
if isinstance(model_sampling, comfy.model_sampling.CONST):
|
||||
# 1 / (1 + exp(half_log_snr))
|
||||
return half_log_snr.neg().sigmoid()
|
||||
return half_log_snr.neg().exp()
|
||||
|
||||
|
||||
def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
|
||||
"""Adjust the first sigma to avoid invalid logSNR."""
|
||||
if len(sigmas) <= 1:
|
||||
return sigmas
|
||||
if isinstance(model_sampling, comfy.model_sampling.CONST):
|
||||
if sigmas[0] >= 1:
|
||||
sigmas = sigmas.clone()
|
||||
sigmas[0] = model_sampling.percent_to_sigma(percent_offset)
|
||||
return sigmas
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
@ -753,6 +781,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
"""DPM-Solver++(2M) SDE."""
|
||||
@ -768,9 +797,12 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
old_denoised = None
|
||||
h_last = None
|
||||
h = None
|
||||
h, h_last = None, None
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
@ -781,26 +813,29 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
x = denoised
|
||||
else:
|
||||
# DPM-Solver++(2M) SDE
|
||||
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = s - t
|
||||
eta_h = eta * h
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
|
||||
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
|
||||
|
||||
if old_denoised is not None:
|
||||
r = h_last / h
|
||||
if solver_type == 'heun':
|
||||
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
|
||||
x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised)
|
||||
elif solver_type == 'midpoint':
|
||||
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
|
||||
x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised)
|
||||
|
||||
if eta:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
|
||||
if eta > 0 and s_noise > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
|
||||
|
||||
old_denoised = denoised
|
||||
h_last = h
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""DPM-Solver++(3M) SDE."""
|
||||
@ -814,6 +849,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
denoised_1, denoised_2 = None, None
|
||||
h, h_1, h_2 = None, None, None
|
||||
|
||||
@ -825,13 +864,16 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = s - t
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
|
||||
x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised
|
||||
|
||||
if h_2 is not None:
|
||||
# DPM-Solver++(3M) SDE
|
||||
r0 = h_1 / h
|
||||
r1 = h_2 / h
|
||||
d1_0 = (denoised - denoised_1) / r0
|
||||
@ -840,20 +882,22 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
d2 = (d1_0 - d1_1) / (r0 + r1)
|
||||
phi_2 = h_eta.neg().expm1() / h_eta + 1
|
||||
phi_3 = phi_2 / h_eta - 0.5
|
||||
x = x + phi_2 * d1 - phi_3 * d2
|
||||
x = x + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) * d2
|
||||
elif h_1 is not None:
|
||||
# DPM-Solver++(2M) SDE
|
||||
r = h_1 / h
|
||||
d = (denoised - denoised_1) / r
|
||||
phi_2 = h_eta.neg().expm1() / h_eta + 1
|
||||
x = x + phi_2 * d
|
||||
x = x + (alpha_t * phi_2) * d
|
||||
|
||||
if eta:
|
||||
if eta > 0 and s_noise > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
|
||||
|
||||
denoised_1, denoised_2 = denoised, denoised_1
|
||||
h_1, h_2 = h, h_1
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if len(sigmas) <= 1:
|
||||
@ -863,6 +907,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
if len(sigmas) <= 1:
|
||||
@ -872,6 +917,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
if len(sigmas) <= 1:
|
||||
@ -1449,12 +1495,12 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
||||
'''
|
||||
SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2
|
||||
Arxiv: https://arxiv.org/abs/2305.14267
|
||||
'''
|
||||
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
||||
arXiv: https://arxiv.org/abs/2305.14267
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
@ -1462,6 +1508,11 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
@ -1469,80 +1520,96 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = t_next - t
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
s = t + r * h
|
||||
lambda_s_1 = lambda_s + r * h
|
||||
fac = 1 / (2 * r)
|
||||
sigma_s = s.neg().exp()
|
||||
sigma_s_1 = sigma_fn(lambda_s_1)
|
||||
|
||||
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
# 0 < r < 1
|
||||
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
|
||||
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1])
|
||||
noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
|
||||
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])
|
||||
|
||||
# Step 1
|
||||
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
denoised_d = (1 - fac) * denoised + fac * denoised_2
|
||||
x = (coeff_2 + 1) * x - coeff_2 * denoised_d
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
|
||||
'''
|
||||
SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3
|
||||
Arxiv: https://arxiv.org/abs/2305.14267
|
||||
'''
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = t_next - t
|
||||
h_eta = h * (eta + 1)
|
||||
s_1 = t + r_1 * h
|
||||
s_2 = t + r_2 * h
|
||||
sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp()
|
||||
|
||||
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt()
|
||||
noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
|
||||
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
|
||||
|
||||
# Step 1
|
||||
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
|
||||
denoised_d = (1 - fac) * denoised + fac * denoised_2
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
|
||||
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
|
||||
arXiv: https://arxiv.org/abs/2305.14267
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
inject_noise = eta > 0 and s_noise > 0
|
||||
|
||||
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
|
||||
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
|
||||
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
|
||||
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
|
||||
h = lambda_t - lambda_s
|
||||
h_eta = h * (eta + 1)
|
||||
lambda_s_1 = lambda_s + r_1 * h
|
||||
lambda_s_2 = lambda_s + r_2 * h
|
||||
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
|
||||
|
||||
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
|
||||
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
|
||||
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
|
||||
alpha_t = sigmas[i + 1] * lambda_t.exp()
|
||||
|
||||
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
|
||||
if inject_noise:
|
||||
# 0 < r_1 < r_2 < 1
|
||||
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
|
||||
noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
|
||||
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
|
||||
|
||||
# Step 1
|
||||
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
|
||||
if inject_noise:
|
||||
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
|
||||
if inject_noise:
|
||||
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
|
||||
|
||||
# Step 3
|
||||
x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
|
||||
if inject_noise:
|
||||
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
|
||||
return x
|
||||
|
@ -26,16 +26,6 @@ from torch import nn
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
t: torch.Tensor,
|
||||
freqs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
|
||||
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
|
||||
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
|
||||
return t_out
|
||||
|
||||
|
||||
def get_normalization(name: str, channels: int, weight_args={}, operations=None):
|
||||
if name == "I":
|
||||
return nn.Identity()
|
||||
|
@ -66,15 +66,16 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
|
||||
h_extrapolation_ratio: float = 1.0,
|
||||
w_extrapolation_ratio: float = 1.0,
|
||||
t_extrapolation_ratio: float = 1.0,
|
||||
enable_fps_modulation: bool = True,
|
||||
device=None,
|
||||
**kwargs, # used for compatibility with other positional embeddings; unused in this class
|
||||
):
|
||||
del kwargs
|
||||
super().__init__()
|
||||
self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device))
|
||||
self.base_fps = base_fps
|
||||
self.max_h = len_h
|
||||
self.max_w = len_w
|
||||
self.enable_fps_modulation = enable_fps_modulation
|
||||
|
||||
dim = head_dim
|
||||
dim_h = dim // 6 * 2
|
||||
@ -132,21 +133,19 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
|
||||
temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))
|
||||
|
||||
B, T, H, W, _ = B_T_H_W_C
|
||||
seq = torch.arange(max(H, W, T), dtype=torch.float, device=device)
|
||||
uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
|
||||
assert (
|
||||
uniform_fps or B == 1 or T == 1
|
||||
), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
|
||||
assert (
|
||||
H <= self.max_h and W <= self.max_w
|
||||
), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
|
||||
half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs)
|
||||
half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs)
|
||||
half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs)
|
||||
half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs)
|
||||
|
||||
# apply sequence scaling in temporal dimension
|
||||
if fps is None: # image case
|
||||
half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs)
|
||||
if fps is None or self.enable_fps_modulation is False: # image case
|
||||
half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs)
|
||||
else:
|
||||
half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
|
||||
half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
|
||||
|
||||
half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
|
||||
half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)
|
||||
|
864
comfy/ldm/cosmos/predict2.py
Normal file
864
comfy/ldm/cosmos/predict2.py
Normal file
@ -0,0 +1,864 @@
|
||||
# original code from: https://github.com/nvidia-cosmos/cosmos-predict2
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
import logging
|
||||
from typing import Callable, Optional, Tuple
|
||||
import math
|
||||
|
||||
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
|
||||
from torchvision import transforms
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
t: torch.Tensor,
|
||||
freqs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
|
||||
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
|
||||
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
|
||||
return t_out
|
||||
|
||||
|
||||
# ---------------------- Feed Forward Network -----------------------
|
||||
class GPT2FeedForward(nn.Module):
|
||||
def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.activation = nn.GELU()
|
||||
self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype)
|
||||
self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype)
|
||||
|
||||
self._layer_id = None
|
||||
self._dim = d_model
|
||||
self._hidden_dim = d_ff
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.layer1(x)
|
||||
|
||||
x = self.activation(x)
|
||||
x = self.layer2(x)
|
||||
return x
|
||||
|
||||
|
||||
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes multi-head attention using PyTorch's native implementation.
|
||||
|
||||
This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
|
||||
It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product
|
||||
attention, and rearranges the output back to the original format.
|
||||
|
||||
The input tensor names use the following dimension conventions:
|
||||
|
||||
- B: batch size
|
||||
- S: sequence length
|
||||
- H: number of attention heads
|
||||
- D: head dimension
|
||||
|
||||
Args:
|
||||
q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim)
|
||||
k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim)
|
||||
v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim)
|
||||
|
||||
Returns:
|
||||
Attention output tensor with shape (batch, seq_len, n_heads * head_dim)
|
||||
"""
|
||||
in_q_shape = q_B_S_H_D.shape
|
||||
in_k_shape = k_B_S_H_D.shape
|
||||
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
|
||||
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
|
||||
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
|
||||
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
A flexible attention module supporting both self-attention and cross-attention mechanisms.
|
||||
|
||||
This module implements a multi-head attention layer that can operate in either self-attention
|
||||
or cross-attention mode. The mode is determined by whether a context dimension is provided.
|
||||
The implementation uses scaled dot-product attention and supports optional bias terms and
|
||||
dropout regularization.
|
||||
|
||||
Args:
|
||||
query_dim (int): The dimensionality of the query vectors.
|
||||
context_dim (int, optional): The dimensionality of the context (key/value) vectors.
|
||||
If None, the module operates in self-attention mode using query_dim. Default: None
|
||||
n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8
|
||||
head_dim (int, optional): The dimension of each attention head. Default: 64
|
||||
dropout (float, optional): Dropout probability applied to the output. Default: 0.0
|
||||
qkv_format (str, optional): Format specification for QKV tensors. Default: "bshd"
|
||||
backend (str, optional): Backend to use for the attention operation. Default: "transformer_engine"
|
||||
|
||||
Examples:
|
||||
>>> # Self-attention with 512 dimensions and 8 heads
|
||||
>>> self_attn = Attention(query_dim=512)
|
||||
>>> x = torch.randn(32, 16, 512) # (batch_size, seq_len, dim)
|
||||
>>> out = self_attn(x) # (32, 16, 512)
|
||||
|
||||
>>> # Cross-attention
|
||||
>>> cross_attn = Attention(query_dim=512, context_dim=256)
|
||||
>>> query = torch.randn(32, 16, 512)
|
||||
>>> context = torch.randn(32, 8, 256)
|
||||
>>> out = cross_attn(query, context) # (32, 16, 512)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
context_dim: Optional[int] = None,
|
||||
n_heads: int = 8,
|
||||
head_dim: int = 64,
|
||||
dropout: float = 0.0,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
logging.debug(
|
||||
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
||||
f"{n_heads} heads with a dimension of {head_dim}."
|
||||
)
|
||||
self.is_selfattn = context_dim is None # self attention
|
||||
|
||||
context_dim = query_dim if context_dim is None else context_dim
|
||||
inner_dim = head_dim * n_heads
|
||||
|
||||
self.n_heads = n_heads
|
||||
self.head_dim = head_dim
|
||||
self.query_dim = query_dim
|
||||
self.context_dim = context_dim
|
||||
|
||||
self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
|
||||
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
|
||||
self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
|
||||
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
|
||||
self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
|
||||
self.v_norm = nn.Identity()
|
||||
|
||||
self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
|
||||
self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()
|
||||
|
||||
self.attn_op = torch_attention_op
|
||||
|
||||
self._query_dim = query_dim
|
||||
self._context_dim = context_dim
|
||||
self._inner_dim = inner_dim
|
||||
|
||||
def compute_qkv(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
rope_emb: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
q = self.q_proj(x)
|
||||
context = x if context is None else context
|
||||
k = self.k_proj(context)
|
||||
v = self.v_proj(context)
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
def apply_norm_and_rotary_pos_emb(
|
||||
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
v = self.v_norm(v)
|
||||
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
|
||||
q = apply_rotary_pos_emb(q, rope_emb)
|
||||
k = apply_rotary_pos_emb(k, rope_emb)
|
||||
return q, k, v
|
||||
|
||||
q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)
|
||||
|
||||
return q, k, v
|
||||
|
||||
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
||||
result = self.attn_op(q, k, v) # [B, S, H, D]
|
||||
return self.output_dropout(self.output_proj(result))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
rope_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): The query tensor of shape [B, Mq, K]
|
||||
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
|
||||
"""
|
||||
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
|
||||
return self.compute_attention(q, k, v)
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels: int):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
|
||||
def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor:
|
||||
assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}"
|
||||
timesteps = timesteps_B_T.flatten().float()
|
||||
half_dim = self.num_channels // 2
|
||||
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
|
||||
exponent = exponent / (half_dim - 0.0)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
sin_emb = torch.sin(emb)
|
||||
cos_emb = torch.cos(emb)
|
||||
emb = torch.cat([cos_emb, sin_emb], dim=-1)
|
||||
|
||||
return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1])
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
logging.debug(
|
||||
f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
|
||||
)
|
||||
self.in_dim = in_features
|
||||
self.out_dim = out_features
|
||||
self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype)
|
||||
self.activation = nn.SiLU()
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
if use_adaln_lora:
|
||||
self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype)
|
||||
else:
|
||||
self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
emb = self.linear_1(sample)
|
||||
emb = self.activation(emb)
|
||||
emb = self.linear_2(emb)
|
||||
|
||||
if self.use_adaln_lora:
|
||||
adaln_lora_B_T_3D = emb
|
||||
emb_B_T_D = sample
|
||||
else:
|
||||
adaln_lora_B_T_3D = None
|
||||
emb_B_T_D = emb
|
||||
|
||||
return emb_B_T_D, adaln_lora_B_T_3D
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""
|
||||
PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers,
|
||||
depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions,
|
||||
making it suitable for video and image processing tasks. It supports dividing the input into patches
|
||||
and embedding each patch into a vector of size `out_channels`.
|
||||
|
||||
Parameters:
|
||||
- spatial_patch_size (int): The size of each spatial patch.
|
||||
- temporal_patch_size (int): The size of each temporal patch.
|
||||
- in_channels (int): Number of input channels. Default: 3.
|
||||
- out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
|
||||
- bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spatial_patch_size: int,
|
||||
temporal_patch_size: int,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 768,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.spatial_patch_size = spatial_patch_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
|
||||
self.proj = nn.Sequential(
|
||||
Rearrange(
|
||||
"b c (t r) (h m) (w n) -> b t h w (c r m n)",
|
||||
r=temporal_patch_size,
|
||||
m=spatial_patch_size,
|
||||
n=spatial_patch_size,
|
||||
),
|
||||
operations.Linear(
|
||||
in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype
|
||||
),
|
||||
)
|
||||
self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of the PatchEmbed module.
|
||||
|
||||
Parameters:
|
||||
- x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
|
||||
B is the batch size,
|
||||
C is the number of channels,
|
||||
T is the temporal dimension,
|
||||
H is the height, and
|
||||
W is the width of the input.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
|
||||
"""
|
||||
assert x.dim() == 5
|
||||
_, _, T, H, W = x.shape
|
||||
assert (
|
||||
H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
|
||||
), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}"
|
||||
assert T % self.temporal_patch_size == 0
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of video DiT.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
spatial_patch_size: int,
|
||||
temporal_patch_size: int,
|
||||
out_channels: int,
|
||||
use_adaln_lora: bool = False,
|
||||
adaln_lora_dim: int = 256,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = operations.Linear(
|
||||
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
|
||||
)
|
||||
self.hidden_size = hidden_size
|
||||
self.n_adaln_chunks = 2
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
self.adaln_lora_dim = adaln_lora_dim
|
||||
if use_adaln_lora:
|
||||
self.adaln_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype),
|
||||
operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype),
|
||||
)
|
||||
else:
|
||||
self.adaln_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x_B_T_H_W_D: torch.Tensor,
|
||||
emb_B_T_D: torch.Tensor,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.use_adaln_lora:
|
||||
assert adaln_lora_B_T_3D is not None
|
||||
shift_B_T_D, scale_B_T_D = (
|
||||
self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
|
||||
).chunk(2, dim=-1)
|
||||
else:
|
||||
shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
|
||||
|
||||
shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d"), rearrange(
|
||||
scale_B_T_D, "b t d -> b t 1 1 d"
|
||||
)
|
||||
|
||||
def _fn(
|
||||
_x_B_T_H_W_D: torch.Tensor,
|
||||
_norm_layer: nn.Module,
|
||||
_scale_B_T_1_1_D: torch.Tensor,
|
||||
_shift_B_T_1_1_D: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
|
||||
|
||||
x_B_T_H_W_D = _fn(x_B_T_H_W_D, self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D)
|
||||
x_B_T_H_W_O = self.linear(x_B_T_H_W_D)
|
||||
return x_B_T_H_W_O
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
"""
|
||||
A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation.
|
||||
Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation.
|
||||
|
||||
Parameters:
|
||||
x_dim (int): Dimension of input features
|
||||
context_dim (int): Dimension of context features for cross-attention
|
||||
num_heads (int): Number of attention heads
|
||||
mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0
|
||||
use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False
|
||||
adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256
|
||||
|
||||
The block applies the following sequence:
|
||||
1. Self-attention with AdaLN modulation
|
||||
2. Cross-attention with AdaLN modulation
|
||||
3. MLP with AdaLN modulation
|
||||
|
||||
Each component uses skip connections and layer normalization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x_dim: int,
|
||||
context_dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
use_adaln_lora: bool = False,
|
||||
adaln_lora_dim: int = 256,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.x_dim = x_dim
|
||||
self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
|
||||
self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
|
||||
self.cross_attn = Attention(
|
||||
x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
|
||||
self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
|
||||
self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
if self.use_adaln_lora:
|
||||
self.adaln_modulation_self_attn = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
|
||||
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
|
||||
)
|
||||
self.adaln_modulation_cross_attn = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
|
||||
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
|
||||
)
|
||||
self.adaln_modulation_mlp = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
|
||||
operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
|
||||
)
|
||||
else:
|
||||
self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
|
||||
self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
|
||||
self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x_B_T_H_W_D: torch.Tensor,
|
||||
emb_B_T_D: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if extra_per_block_pos_emb is not None:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
|
||||
|
||||
if self.use_adaln_lora:
|
||||
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
|
||||
self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
|
||||
).chunk(3, dim=-1)
|
||||
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
|
||||
self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
|
||||
).chunk(3, dim=-1)
|
||||
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (
|
||||
self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D
|
||||
).chunk(3, dim=-1)
|
||||
else:
|
||||
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
|
||||
emb_B_T_D
|
||||
).chunk(3, dim=-1)
|
||||
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
|
||||
emb_B_T_D
|
||||
).chunk(3, dim=-1)
|
||||
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)
|
||||
|
||||
# Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting
|
||||
shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
|
||||
shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d")
|
||||
|
||||
shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d")
|
||||
scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d")
|
||||
gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d")
|
||||
|
||||
B, T, H, W, D = x_B_T_H_W_D.shape
|
||||
|
||||
def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D):
|
||||
return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
|
||||
|
||||
normalized_x_B_T_H_W_D = _fn(
|
||||
x_B_T_H_W_D,
|
||||
self.layer_norm_self_attn,
|
||||
scale_self_attn_B_T_1_1_D,
|
||||
shift_self_attn_B_T_1_1_D,
|
||||
)
|
||||
result_B_T_H_W_D = rearrange(
|
||||
self.self_attn(
|
||||
# normalized_x_B_T_HW_D,
|
||||
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
None,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
),
|
||||
"b (t h w) d -> b t h w d",
|
||||
t=T,
|
||||
h=H,
|
||||
w=W,
|
||||
)
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
|
||||
|
||||
def _x_fn(
|
||||
_x_B_T_H_W_D: torch.Tensor,
|
||||
layer_norm_cross_attn: Callable,
|
||||
_scale_cross_attn_B_T_1_1_D: torch.Tensor,
|
||||
_shift_cross_attn_B_T_1_1_D: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
_normalized_x_B_T_H_W_D = _fn(
|
||||
_x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
|
||||
)
|
||||
_result_B_T_H_W_D = rearrange(
|
||||
self.cross_attn(
|
||||
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
crossattn_emb,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
),
|
||||
"b (t h w) d -> b t h w d",
|
||||
t=T,
|
||||
h=H,
|
||||
w=W,
|
||||
)
|
||||
return _result_B_T_H_W_D
|
||||
|
||||
result_B_T_H_W_D = _x_fn(
|
||||
x_B_T_H_W_D,
|
||||
self.layer_norm_cross_attn,
|
||||
scale_cross_attn_B_T_1_1_D,
|
||||
shift_cross_attn_B_T_1_1_D,
|
||||
)
|
||||
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
|
||||
|
||||
normalized_x_B_T_H_W_D = _fn(
|
||||
x_B_T_H_W_D,
|
||||
self.layer_norm_mlp,
|
||||
scale_mlp_B_T_1_1_D,
|
||||
shift_mlp_B_T_1_1_D,
|
||||
)
|
||||
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
|
||||
return x_B_T_H_W_D
|
||||
|
||||
|
||||
class MiniTrainDIT(nn.Module):
|
||||
"""
|
||||
A clean impl of DIT that can load and reproduce the training results of the original DIT model in~(cosmos 1)
|
||||
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
|
||||
|
||||
Args:
|
||||
max_img_h (int): Maximum height of the input images.
|
||||
max_img_w (int): Maximum width of the input images.
|
||||
max_frames (int): Maximum number of frames in the video sequence.
|
||||
in_channels (int): Number of input channels (e.g., RGB channels for color images).
|
||||
out_channels (int): Number of output channels.
|
||||
patch_spatial (tuple): Spatial resolution of patches for input processing.
|
||||
patch_temporal (int): Temporal resolution of patches for input processing.
|
||||
concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
|
||||
model_channels (int): Base number of channels used throughout the model.
|
||||
num_blocks (int): Number of transformer blocks.
|
||||
num_heads (int): Number of heads in the multi-head attention layers.
|
||||
mlp_ratio (float): Expansion ratio for MLP blocks.
|
||||
crossattn_emb_channels (int): Number of embedding channels for cross-attention.
|
||||
pos_emb_cls (str): Type of positional embeddings.
|
||||
pos_emb_learnable (bool): Whether positional embeddings are learnable.
|
||||
pos_emb_interpolation (str): Method for interpolating positional embeddings.
|
||||
min_fps (int): Minimum frames per second.
|
||||
max_fps (int): Maximum frames per second.
|
||||
use_adaln_lora (bool): Whether to use AdaLN-LoRA.
|
||||
adaln_lora_dim (int): Dimension for AdaLN-LoRA.
|
||||
rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
|
||||
rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
|
||||
rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
|
||||
extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
|
||||
extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
|
||||
extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
|
||||
extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_img_h: int,
|
||||
max_img_w: int,
|
||||
max_frames: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
patch_spatial: int, # tuple,
|
||||
patch_temporal: int,
|
||||
concat_padding_mask: bool = True,
|
||||
# attention settings
|
||||
model_channels: int = 768,
|
||||
num_blocks: int = 10,
|
||||
num_heads: int = 16,
|
||||
mlp_ratio: float = 4.0,
|
||||
# cross attention settings
|
||||
crossattn_emb_channels: int = 1024,
|
||||
# positional embedding settings
|
||||
pos_emb_cls: str = "sincos",
|
||||
pos_emb_learnable: bool = False,
|
||||
pos_emb_interpolation: str = "crop",
|
||||
min_fps: int = 1,
|
||||
max_fps: int = 30,
|
||||
use_adaln_lora: bool = False,
|
||||
adaln_lora_dim: int = 256,
|
||||
rope_h_extrapolation_ratio: float = 1.0,
|
||||
rope_w_extrapolation_ratio: float = 1.0,
|
||||
rope_t_extrapolation_ratio: float = 1.0,
|
||||
extra_per_block_abs_pos_emb: bool = False,
|
||||
extra_h_extrapolation_ratio: float = 1.0,
|
||||
extra_w_extrapolation_ratio: float = 1.0,
|
||||
extra_t_extrapolation_ratio: float = 1.0,
|
||||
rope_enable_fps_modulation: bool = True,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.max_img_h = max_img_h
|
||||
self.max_img_w = max_img_w
|
||||
self.max_frames = max_frames
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.patch_spatial = patch_spatial
|
||||
self.patch_temporal = patch_temporal
|
||||
self.num_heads = num_heads
|
||||
self.num_blocks = num_blocks
|
||||
self.model_channels = model_channels
|
||||
self.concat_padding_mask = concat_padding_mask
|
||||
# positional embedding settings
|
||||
self.pos_emb_cls = pos_emb_cls
|
||||
self.pos_emb_learnable = pos_emb_learnable
|
||||
self.pos_emb_interpolation = pos_emb_interpolation
|
||||
self.min_fps = min_fps
|
||||
self.max_fps = max_fps
|
||||
self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
|
||||
self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
|
||||
self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
|
||||
self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
|
||||
self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
|
||||
self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
|
||||
self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
|
||||
self.rope_enable_fps_modulation = rope_enable_fps_modulation
|
||||
|
||||
self.build_pos_embed(device=device, dtype=dtype)
|
||||
self.use_adaln_lora = use_adaln_lora
|
||||
self.adaln_lora_dim = adaln_lora_dim
|
||||
self.t_embedder = nn.Sequential(
|
||||
Timesteps(model_channels),
|
||||
TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,),
|
||||
)
|
||||
|
||||
in_channels = in_channels + 1 if concat_padding_mask else in_channels
|
||||
self.x_embedder = PatchEmbed(
|
||||
spatial_patch_size=patch_spatial,
|
||||
temporal_patch_size=patch_temporal,
|
||||
in_channels=in_channels,
|
||||
out_channels=model_channels,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Block(
|
||||
x_dim=model_channels,
|
||||
context_dim=crossattn_emb_channels,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
use_adaln_lora=use_adaln_lora,
|
||||
adaln_lora_dim=adaln_lora_dim,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
for _ in range(num_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = FinalLayer(
|
||||
hidden_size=self.model_channels,
|
||||
spatial_patch_size=self.patch_spatial,
|
||||
temporal_patch_size=self.patch_temporal,
|
||||
out_channels=self.out_channels,
|
||||
use_adaln_lora=self.use_adaln_lora,
|
||||
adaln_lora_dim=self.adaln_lora_dim,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
|
||||
self.t_embedding_norm = operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype)
|
||||
|
||||
def build_pos_embed(self, device=None, dtype=None) -> None:
|
||||
if self.pos_emb_cls == "rope3d":
|
||||
cls_type = VideoRopePosition3DEmb
|
||||
else:
|
||||
raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
|
||||
|
||||
logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
|
||||
kwargs = dict(
|
||||
model_channels=self.model_channels,
|
||||
len_h=self.max_img_h // self.patch_spatial,
|
||||
len_w=self.max_img_w // self.patch_spatial,
|
||||
len_t=self.max_frames // self.patch_temporal,
|
||||
max_fps=self.max_fps,
|
||||
min_fps=self.min_fps,
|
||||
is_learnable=self.pos_emb_learnable,
|
||||
interpolation=self.pos_emb_interpolation,
|
||||
head_dim=self.model_channels // self.num_heads,
|
||||
h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
|
||||
w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
|
||||
t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
|
||||
enable_fps_modulation=self.rope_enable_fps_modulation,
|
||||
device=device,
|
||||
)
|
||||
self.pos_embedder = cls_type(
|
||||
**kwargs, # type: ignore
|
||||
)
|
||||
|
||||
if self.extra_per_block_abs_pos_emb:
|
||||
kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
|
||||
kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
|
||||
kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
|
||||
kwargs["device"] = device
|
||||
kwargs["dtype"] = dtype
|
||||
self.extra_pos_embedder = LearnablePosEmbAxis(
|
||||
**kwargs, # type: ignore
|
||||
)
|
||||
|
||||
def prepare_embedded_sequence(
|
||||
self,
|
||||
x_B_C_T_H_W: torch.Tensor,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.
|
||||
|
||||
Args:
|
||||
x_B_C_T_H_W (torch.Tensor): video
|
||||
fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
|
||||
If None, a default value (`self.base_fps`) will be used.
|
||||
padding_mask (Optional[torch.Tensor]): current it is not used
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
- A tensor of shape (B, T, H, W, D) with the embedded sequence.
|
||||
- An optional positional embedding tensor, returned only if the positional embedding class
|
||||
(`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
|
||||
|
||||
Notes:
|
||||
- If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
|
||||
- The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
|
||||
- If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
|
||||
the `self.pos_embedder` with the shape [T, H, W].
|
||||
- If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
|
||||
`self.pos_embedder` with the fps tensor.
|
||||
- Otherwise, the positional embeddings are generated without considering fps.
|
||||
"""
|
||||
if self.concat_padding_mask:
|
||||
if padding_mask is None:
|
||||
padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
|
||||
else:
|
||||
padding_mask = transforms.functional.resize(
|
||||
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
|
||||
)
|
||||
x_B_C_T_H_W = torch.cat(
|
||||
[x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
|
||||
)
|
||||
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
|
||||
|
||||
if self.extra_per_block_abs_pos_emb:
|
||||
extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
|
||||
else:
|
||||
extra_pos_emb = None
|
||||
|
||||
if "rope" in self.pos_emb_cls.lower():
|
||||
return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
|
||||
|
||||
return x_B_T_H_W_D, None, extra_pos_emb
|
||||
|
||||
def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor:
|
||||
x_B_C_Tt_Hp_Wp = rearrange(
|
||||
x_B_T_H_W_M,
|
||||
"B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
|
||||
p1=self.patch_spatial,
|
||||
p2=self.patch_spatial,
|
||||
t=self.patch_temporal,
|
||||
)
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
x_B_C_T_H_W = x
|
||||
timesteps_B_T = timesteps
|
||||
crossattn_emb = context
|
||||
"""
|
||||
Args:
|
||||
x: (B, C, T, H, W) tensor of spatial-temp inputs
|
||||
timesteps: (B, ) tensor of timesteps
|
||||
crossattn_emb: (B, N, D) tensor of cross-attention embeddings
|
||||
"""
|
||||
x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
|
||||
x_B_C_T_H_W,
|
||||
fps=fps,
|
||||
padding_mask=padding_mask,
|
||||
)
|
||||
|
||||
if timesteps_B_T.ndim == 1:
|
||||
timesteps_B_T = timesteps_B_T.unsqueeze(1)
|
||||
t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype))
|
||||
t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D)
|
||||
|
||||
# for logging purpose
|
||||
affline_scale_log_info = {}
|
||||
affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach()
|
||||
self.affline_scale_log_info = affline_scale_log_info
|
||||
self.affline_emb = t_embedding_B_T_D
|
||||
self.crossattn_emb = crossattn_emb
|
||||
|
||||
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
|
||||
assert (
|
||||
x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
|
||||
), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}"
|
||||
|
||||
block_kwargs = {
|
||||
"rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
|
||||
"adaln_lora_B_T_3D": adaln_lora_B_T_3D,
|
||||
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
|
||||
}
|
||||
for block in self.blocks:
|
||||
x_B_T_H_W_D = block(
|
||||
x_B_T_H_W_D,
|
||||
t_embedding_B_T_D,
|
||||
crossattn_emb,
|
||||
**block_kwargs,
|
||||
)
|
||||
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
|
||||
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
|
||||
return x_B_C_Tt_Hp_Wp
|
@ -121,6 +121,9 @@ class ControlNetFlux(Flux):
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
if y is None:
|
||||
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
|
||||
@ -174,7 +177,7 @@ class ControlNetFlux(Flux):
|
||||
out["output"] = out_output[:self.main_model_single]
|
||||
return out
|
||||
|
||||
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
|
||||
def forward(self, x, timesteps, context, y=None, guidance=None, hint=None, **kwargs):
|
||||
patch_size = 2
|
||||
if self.latent_input:
|
||||
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
|
||||
|
@ -101,6 +101,10 @@ class Flux(nn.Module):
|
||||
transformer_options={},
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
|
||||
if y is None:
|
||||
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
@ -155,6 +159,9 @@ class Flux(nn.Module):
|
||||
if add is not None:
|
||||
img += add
|
||||
|
||||
if img.dtype == torch.float16:
|
||||
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
@ -188,7 +195,7 @@ class Flux(nn.Module):
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
|
||||
def forward(self, x, timestep, context, y=None, guidance=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
@ -261,8 +261,8 @@ class CrossAttention(nn.Module):
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.q_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
|
||||
self.k_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
|
||||
|
||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
@ -753,7 +753,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
for p in patch:
|
||||
n = p(n, extra_options)
|
||||
|
||||
x += n
|
||||
x = n + x
|
||||
if "middle_patch" in transformer_patches:
|
||||
patch = transformer_patches["middle_patch"]
|
||||
for p in patch:
|
||||
@ -793,12 +793,12 @@ class BasicTransformerBlock(nn.Module):
|
||||
for p in patch:
|
||||
n = p(n, extra_options)
|
||||
|
||||
x += n
|
||||
x = n + x
|
||||
if self.is_res:
|
||||
x_skip = x
|
||||
x = self.ff(self.norm3(x))
|
||||
if self.is_res:
|
||||
x += x_skip
|
||||
x = x_skip + x
|
||||
|
||||
return x
|
||||
|
||||
|
@ -34,6 +34,7 @@ import comfy.ldm.flux.model
|
||||
import comfy.ldm.lightricks.model
|
||||
import comfy.ldm.hunyuan_video.model
|
||||
import comfy.ldm.cosmos.model
|
||||
import comfy.ldm.cosmos.predict2
|
||||
import comfy.ldm.lumina.model
|
||||
import comfy.ldm.wan.model
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
@ -48,6 +49,7 @@ import comfy.ops
|
||||
from enum import Enum
|
||||
from . import utils
|
||||
import comfy.latent_formats
|
||||
import comfy.model_sampling
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
@ -63,38 +65,39 @@ class ModelType(Enum):
|
||||
V_PREDICTION_CONTINUOUS = 7
|
||||
FLUX = 8
|
||||
IMG_TO_IMG = 9
|
||||
|
||||
|
||||
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
|
||||
FLOW_COSMOS = 10
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
s = ModelSamplingDiscrete
|
||||
s = comfy.model_sampling.ModelSamplingDiscrete
|
||||
|
||||
if model_type == ModelType.EPS:
|
||||
c = EPS
|
||||
c = comfy.model_sampling.EPS
|
||||
elif model_type == ModelType.V_PREDICTION:
|
||||
c = V_PREDICTION
|
||||
c = comfy.model_sampling.V_PREDICTION
|
||||
elif model_type == ModelType.V_PREDICTION_EDM:
|
||||
c = V_PREDICTION
|
||||
s = ModelSamplingContinuousEDM
|
||||
c = comfy.model_sampling.V_PREDICTION
|
||||
s = comfy.model_sampling.ModelSamplingContinuousEDM
|
||||
elif model_type == ModelType.FLOW:
|
||||
c = comfy.model_sampling.CONST
|
||||
s = comfy.model_sampling.ModelSamplingDiscreteFlow
|
||||
elif model_type == ModelType.STABLE_CASCADE:
|
||||
c = EPS
|
||||
s = StableCascadeSampling
|
||||
c = comfy.model_sampling.EPS
|
||||
s = comfy.model_sampling.StableCascadeSampling
|
||||
elif model_type == ModelType.EDM:
|
||||
c = EDM
|
||||
s = ModelSamplingContinuousEDM
|
||||
c = comfy.model_sampling.EDM
|
||||
s = comfy.model_sampling.ModelSamplingContinuousEDM
|
||||
elif model_type == ModelType.V_PREDICTION_CONTINUOUS:
|
||||
c = V_PREDICTION
|
||||
s = ModelSamplingContinuousV
|
||||
c = comfy.model_sampling.V_PREDICTION
|
||||
s = comfy.model_sampling.ModelSamplingContinuousV
|
||||
elif model_type == ModelType.FLUX:
|
||||
c = comfy.model_sampling.CONST
|
||||
s = comfy.model_sampling.ModelSamplingFlux
|
||||
elif model_type == ModelType.IMG_TO_IMG:
|
||||
c = comfy.model_sampling.IMG_TO_IMG
|
||||
elif model_type == ModelType.FLOW_COSMOS:
|
||||
c = comfy.model_sampling.COSMOS_RFLOW
|
||||
s = comfy.model_sampling.ModelSamplingCosmosRFlow
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@ -102,6 +105,13 @@ def model_sampling(model_config, model_type):
|
||||
return ModelSampling(model_config)
|
||||
|
||||
|
||||
def convert_tensor(extra, dtype):
|
||||
if hasattr(extra, "dtype"):
|
||||
if extra.dtype != torch.int and extra.dtype != torch.long:
|
||||
extra = extra.to(dtype)
|
||||
return extra
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
|
||||
super().__init__()
|
||||
@ -165,13 +175,13 @@ class BaseModel(torch.nn.Module):
|
||||
extra_conds = {}
|
||||
for o in kwargs:
|
||||
extra = kwargs[o]
|
||||
|
||||
if hasattr(extra, "dtype"):
|
||||
if extra.dtype != torch.int and extra.dtype != torch.long:
|
||||
extra = extra.to(dtype)
|
||||
if isinstance(extra, list):
|
||||
extra = convert_tensor(extra, dtype)
|
||||
elif isinstance(extra, list):
|
||||
ex = []
|
||||
for ext in extra:
|
||||
ex.append(ext.to(dtype))
|
||||
ex.append(convert_tensor(ext, dtype))
|
||||
extra = ex
|
||||
extra_conds[o] = extra
|
||||
|
||||
@ -991,6 +1001,43 @@ class CosmosVideo(BaseModel):
|
||||
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
|
||||
return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5)
|
||||
|
||||
class CosmosPredict2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW_COSMOS, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.predict2.MiniTrainDIT)
|
||||
self.image_to_video = image_to_video
|
||||
if self.image_to_video:
|
||||
self.concat_keys = ("mask_inverted",)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if denoise_mask is not None:
|
||||
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
|
||||
|
||||
out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None))
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
|
||||
if denoise_mask is None:
|
||||
return timestep
|
||||
condition_video_mask_B_1_T_1_1 = denoise_mask.mean(dim=[1, 3, 4], keepdim=True)
|
||||
c_noise_B_1_T_1_1 = 0.0 * (1.0 - condition_video_mask_B_1_T_1_1) + timestep.reshape(timestep.shape[0], 1, 1, 1, 1) * condition_video_mask_B_1_T_1_1
|
||||
out = c_noise_B_1_T_1_1.squeeze(dim=[1, 3, 4])
|
||||
return out
|
||||
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1))
|
||||
sigma_noise_augmentation = 0 #TODO
|
||||
if sigma_noise_augmentation != 0:
|
||||
latent_image = latent_image + noise
|
||||
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
|
||||
sigma = (sigma / (sigma + 1))
|
||||
return latent_image / (1.0 - sigma)
|
||||
|
||||
class Lumina2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
|
||||
|
@ -407,6 +407,58 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["text_emb_dim"] = 2048
|
||||
return dit_config
|
||||
|
||||
if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "cosmos_predict2"
|
||||
dit_config["max_img_h"] = 240
|
||||
dit_config["max_img_w"] = 240
|
||||
dit_config["max_frames"] = 128
|
||||
concat_padding_mask = True
|
||||
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["patch_spatial"] = 2
|
||||
dit_config["patch_temporal"] = 1
|
||||
dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["concat_padding_mask"] = concat_padding_mask
|
||||
dit_config["crossattn_emb_channels"] = 1024
|
||||
dit_config["pos_emb_cls"] = "rope3d"
|
||||
dit_config["pos_emb_learnable"] = True
|
||||
dit_config["pos_emb_interpolation"] = "crop"
|
||||
dit_config["min_fps"] = 1
|
||||
dit_config["max_fps"] = 30
|
||||
|
||||
dit_config["use_adaln_lora"] = True
|
||||
dit_config["adaln_lora_dim"] = 256
|
||||
if dit_config["model_channels"] == 2048:
|
||||
dit_config["num_blocks"] = 28
|
||||
dit_config["num_heads"] = 16
|
||||
elif dit_config["model_channels"] == 5120:
|
||||
dit_config["num_blocks"] = 36
|
||||
dit_config["num_heads"] = 40
|
||||
|
||||
if dit_config["in_channels"] == 16:
|
||||
dit_config["extra_per_block_abs_pos_emb"] = False
|
||||
dit_config["rope_h_extrapolation_ratio"] = 4.0
|
||||
dit_config["rope_w_extrapolation_ratio"] = 4.0
|
||||
dit_config["rope_t_extrapolation_ratio"] = 1.0
|
||||
elif dit_config["in_channels"] == 17: # img to video
|
||||
if dit_config["model_channels"] == 2048:
|
||||
dit_config["extra_per_block_abs_pos_emb"] = False
|
||||
dit_config["rope_h_extrapolation_ratio"] = 3.0
|
||||
dit_config["rope_w_extrapolation_ratio"] = 3.0
|
||||
dit_config["rope_t_extrapolation_ratio"] = 1.0
|
||||
elif dit_config["model_channels"] == 5120:
|
||||
dit_config["rope_h_extrapolation_ratio"] = 2.0
|
||||
dit_config["rope_w_extrapolation_ratio"] = 2.0
|
||||
dit_config["rope_t_extrapolation_ratio"] = 0.8333333333333334
|
||||
|
||||
dit_config["extra_h_extrapolation_ratio"] = 1.0
|
||||
dit_config["extra_w_extrapolation_ratio"] = 1.0
|
||||
dit_config["extra_t_extrapolation_ratio"] = 1.0
|
||||
dit_config["rope_enable_fps_modulation"] = False
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
|
@ -319,6 +319,7 @@ except:
|
||||
pass
|
||||
|
||||
|
||||
SUPPORT_FP8_OPS = args.supports_fp8_compute
|
||||
try:
|
||||
if is_amd():
|
||||
try:
|
||||
@ -329,9 +330,13 @@ try:
|
||||
logging.info("AMD arch: {}".format(arch))
|
||||
logging.info("ROCm version: {}".format(rocm_version))
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches
|
||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx1201 and gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
|
||||
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
|
||||
SUPPORT_FP8_OPS = True
|
||||
|
||||
except:
|
||||
pass
|
||||
|
||||
@ -352,7 +357,7 @@ except:
|
||||
pass
|
||||
|
||||
try:
|
||||
if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5:
|
||||
if torch_version_numeric >= (2, 5):
|
||||
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
|
||||
except:
|
||||
logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
|
||||
@ -1075,7 +1080,7 @@ def pytorch_attention_flash_attention():
|
||||
global ENABLE_PYTORCH_ATTENTION
|
||||
if ENABLE_PYTORCH_ATTENTION:
|
||||
#TODO: more reliable way of checking for flash attention?
|
||||
if is_nvidia(): #pytorch flash attention only works on Nvidia
|
||||
if is_nvidia():
|
||||
return True
|
||||
if is_intel_xpu():
|
||||
return True
|
||||
@ -1091,7 +1096,7 @@ def force_upcast_attention_dtype():
|
||||
upcast = args.force_upcast_attention
|
||||
|
||||
macos_version = mac_version()
|
||||
if macos_version is not None and ((14, 5) <= macos_version < (16,)): # black image bug on recent versions of macOS
|
||||
if macos_version is not None and ((14, 5) <= macos_version): # black image bug on recent versions of macOS, I don't think it's ever getting fixed
|
||||
upcast = True
|
||||
|
||||
if upcast:
|
||||
@ -1290,7 +1295,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
def supports_fp8_compute(device=None):
|
||||
if args.supports_fp8_compute:
|
||||
if SUPPORT_FP8_OPS:
|
||||
return True
|
||||
|
||||
if not is_nvidia():
|
||||
@ -1304,11 +1309,11 @@ def supports_fp8_compute(device=None):
|
||||
if props.minor < 9:
|
||||
return False
|
||||
|
||||
if torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 3):
|
||||
if torch_version_numeric < (2, 3):
|
||||
return False
|
||||
|
||||
if WINDOWS:
|
||||
if (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 4):
|
||||
if torch_version_numeric < (2, 4):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
@ -17,23 +17,26 @@
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Optional, Callable
|
||||
import torch
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
import uuid
|
||||
import collections
|
||||
import math
|
||||
import uuid
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.utils
|
||||
import comfy.float
|
||||
import comfy.model_management
|
||||
import comfy.lora
|
||||
import comfy.hooks
|
||||
import comfy.lora
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
|
||||
import comfy.utils
|
||||
from comfy.comfy_types import UnetWrapperFunction
|
||||
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
|
||||
|
||||
|
||||
def string_to_seed(data):
|
||||
crc = 0xFFFFFFFF
|
||||
|
@ -77,6 +77,25 @@ class IMG_TO_IMG(X0):
|
||||
def calculate_input(self, sigma, noise):
|
||||
return noise
|
||||
|
||||
class COSMOS_RFLOW:
|
||||
def calculate_input(self, sigma, noise):
|
||||
sigma = (sigma / (sigma + 1))
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
return noise * (1.0 - sigma)
|
||||
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = (sigma / (sigma + 1))
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input * (1.0 - sigma) - model_output * sigma
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
noise = noise * sigma
|
||||
noise += latent_image
|
||||
return noise
|
||||
|
||||
def inverse_noise_scaling(self, sigma, latent):
|
||||
return latent
|
||||
|
||||
class ModelSamplingDiscrete(torch.nn.Module):
|
||||
def __init__(self, model_config=None, zsnr=None):
|
||||
@ -350,3 +369,15 @@ class ModelSamplingFlux(torch.nn.Module):
|
||||
if percent >= 1.0:
|
||||
return 0.0
|
||||
return flux_time_shift(self.shift, 1.0, 1.0 - percent)
|
||||
|
||||
|
||||
class ModelSamplingCosmosRFlow(ModelSamplingContinuousEDM):
|
||||
def timestep(self, sigma):
|
||||
return sigma / (sigma + 1)
|
||||
|
||||
def sigma(self, timestep):
|
||||
sigma_max = self.sigma_max
|
||||
if timestep >= (sigma_max / (sigma_max + 1)):
|
||||
return sigma_max
|
||||
|
||||
return timestep / (1 - timestep)
|
||||
|
23
comfy/sd.py
23
comfy/sd.py
@ -1081,7 +1081,28 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
return (model_patcher, clip, vae, clipvision)
|
||||
|
||||
|
||||
def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
|
||||
def load_diffusion_model_state_dict(sd, model_options={}):
|
||||
"""
|
||||
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
|
||||
|
||||
Args:
|
||||
sd (dict): State dictionary containing model weights and configuration
|
||||
model_options (dict, optional): Additional options for model loading. Supports:
|
||||
- dtype: Override model data type
|
||||
- custom_operations: Custom model operations
|
||||
- fp8_optimizations: Enable FP8 optimizations
|
||||
|
||||
Returns:
|
||||
ModelPatcher: A wrapped model instance that handles device management and weight loading.
|
||||
Returns None if the model configuration cannot be detected.
|
||||
|
||||
The function:
|
||||
1. Detects and handles different model formats (regular, diffusers, mmdit)
|
||||
2. Configures model dtype based on parameters and device capabilities
|
||||
3. Handles weight conversion and device placement
|
||||
4. Manages model optimization settings
|
||||
5. Loads weights and returns a device-managed model instance
|
||||
"""
|
||||
dtype = model_options.get("dtype", None)
|
||||
|
||||
#Allow loading unets from checkpoint files
|
||||
|
@ -462,7 +462,7 @@ class SDTokenizer:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
|
||||
self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
|
||||
self.min_length = min_length
|
||||
self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
|
||||
self.end_token = None
|
||||
self.min_padding = min_padding
|
||||
|
||||
|
@ -908,6 +908,48 @@ class CosmosI2V(CosmosT2V):
|
||||
out = model_base.CosmosVideo(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
class CosmosT2IPredict2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "cosmos_predict2",
|
||||
"in_channels": 16,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"sigma_data": 1.0,
|
||||
"sigma_max": 80.0,
|
||||
"sigma_min": 0.002,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Wan21
|
||||
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.CosmosPredict2(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
|
||||
|
||||
class CosmosI2VPredict2(CosmosT2IPredict2):
|
||||
unet_config = {
|
||||
"image_model": "cosmos_predict2",
|
||||
"in_channels": 17,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.CosmosPredict2(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
class Lumina2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "lumina2",
|
||||
@ -1139,6 +1181,6 @@ class ACEStep(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
@ -1,4 +1,4 @@
|
||||
from .base import WeightAdapterBase
|
||||
from .base import WeightAdapterBase, WeightAdapterTrainBase
|
||||
from .lora import LoRAAdapter
|
||||
from .loha import LoHaAdapter
|
||||
from .lokr import LoKrAdapter
|
||||
@ -15,3 +15,9 @@ adapters: list[type[WeightAdapterBase]] = [
|
||||
OFTAdapter,
|
||||
BOFTAdapter,
|
||||
]
|
||||
|
||||
__all__ = [
|
||||
"WeightAdapterBase",
|
||||
"WeightAdapterTrainBase",
|
||||
"adapters"
|
||||
] + [a.__name__ for a in adapters]
|
||||
|
@ -12,12 +12,20 @@ class WeightAdapterBase:
|
||||
weights: list[torch.Tensor]
|
||||
|
||||
@classmethod
|
||||
def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]:
|
||||
def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]:
|
||||
raise NotImplementedError
|
||||
|
||||
def to_train(self) -> "WeightAdapterTrainBase":
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def create_train(cls, weight, *args) -> "WeightAdapterTrainBase":
|
||||
"""
|
||||
weight: The original weight tensor to be modified.
|
||||
*args: Additional arguments for configuration, such as rank, alpha etc.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def calculate_weight(
|
||||
self,
|
||||
weight,
|
||||
@ -33,10 +41,22 @@ class WeightAdapterBase:
|
||||
|
||||
|
||||
class WeightAdapterTrainBase(nn.Module):
|
||||
# We follow the scheme of PR #7032
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# [TODO] Collaborate with LoRA training PR #7032
|
||||
def __call__(self, w):
|
||||
"""
|
||||
w: The original weight tensor to be modified.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def passive_memory_usage(self):
|
||||
raise NotImplementedError("passive_memory_usage is not implemented")
|
||||
|
||||
def move_to(self, device):
|
||||
self.to(device)
|
||||
return self.passive_memory_usage()
|
||||
|
||||
|
||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
|
||||
@ -102,3 +122,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
|
||||
padded_tensor[new_slices] = tensor[orig_slices]
|
||||
|
||||
return padded_tensor
|
||||
|
||||
|
||||
def tucker_weight_from_conv(up, down, mid):
|
||||
up = up.reshape(up.size(0), up.size(1))
|
||||
down = down.reshape(down.size(0), down.size(1))
|
||||
return torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down)
|
||||
|
||||
|
||||
def tucker_weight(wa, wb, t):
|
||||
temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
|
||||
return torch.einsum("i j ..., i r -> r j ...", temp, wa)
|
||||
|
@ -3,7 +3,56 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import comfy.model_management
|
||||
from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
|
||||
from .base import (
|
||||
WeightAdapterBase,
|
||||
WeightAdapterTrainBase,
|
||||
weight_decompose,
|
||||
pad_tensor_to_shape,
|
||||
tucker_weight_from_conv,
|
||||
)
|
||||
|
||||
|
||||
class LoraDiff(WeightAdapterTrainBase):
|
||||
def __init__(self, weights):
|
||||
super().__init__()
|
||||
mat1, mat2, alpha, mid, dora_scale, reshape = weights
|
||||
out_dim, rank = mat1.shape[0], mat1.shape[1]
|
||||
rank, in_dim = mat2.shape[0], mat2.shape[1]
|
||||
if mid is not None:
|
||||
convdim = mid.ndim - 2
|
||||
layer = (
|
||||
torch.nn.Conv1d,
|
||||
torch.nn.Conv2d,
|
||||
torch.nn.Conv3d
|
||||
)[convdim]
|
||||
else:
|
||||
layer = torch.nn.Linear
|
||||
self.lora_up = layer(rank, out_dim, bias=False)
|
||||
self.lora_down = layer(in_dim, rank, bias=False)
|
||||
self.lora_up.weight.data.copy_(mat1)
|
||||
self.lora_down.weight.data.copy_(mat2)
|
||||
if mid is not None:
|
||||
self.lora_mid = layer(mid, rank, bias=False)
|
||||
self.lora_mid.weight.data.copy_(mid)
|
||||
else:
|
||||
self.lora_mid = None
|
||||
self.rank = rank
|
||||
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
|
||||
|
||||
def __call__(self, w):
|
||||
org_dtype = w.dtype
|
||||
if self.lora_mid is None:
|
||||
diff = self.lora_up.weight @ self.lora_down.weight
|
||||
else:
|
||||
diff = tucker_weight_from_conv(
|
||||
self.lora_up.weight, self.lora_down.weight, self.lora_mid.weight
|
||||
)
|
||||
scale = self.alpha / self.rank
|
||||
weight = w + scale * diff.reshape(w.shape)
|
||||
return weight.to(org_dtype)
|
||||
|
||||
def passive_memory_usage(self):
|
||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||
|
||||
|
||||
class LoRAAdapter(WeightAdapterBase):
|
||||
@ -13,6 +62,21 @@ class LoRAAdapter(WeightAdapterBase):
|
||||
self.loaded_keys = loaded_keys
|
||||
self.weights = weights
|
||||
|
||||
@classmethod
|
||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||
out_dim = weight.shape[0]
|
||||
in_dim = weight.shape[1:].numel()
|
||||
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
|
||||
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
|
||||
torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
|
||||
torch.nn.init.constant_(mat2, 0.0)
|
||||
return LoraDiff(
|
||||
(mat1, mat2, alpha, None, None, None)
|
||||
)
|
||||
|
||||
def to_train(self):
|
||||
return LoraDiff(self.weights)
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
|
@ -346,20 +346,6 @@ class FluxKontextProImageNode(ComfyNodeABC):
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, aspect_ratio: str):
|
||||
try:
|
||||
validate_aspect_ratio(
|
||||
aspect_ratio,
|
||||
minimum_ratio=cls.MINIMUM_RATIO,
|
||||
maximum_ratio=cls.MAXIMUM_RATIO,
|
||||
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
|
||||
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
|
||||
)
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
return True
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
@ -380,6 +366,13 @@ class FluxKontextProImageNode(ComfyNodeABC):
|
||||
unique_id: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
aspect_ratio = validate_aspect_ratio(
|
||||
aspect_ratio,
|
||||
minimum_ratio=self.MINIMUM_RATIO,
|
||||
maximum_ratio=self.MAXIMUM_RATIO,
|
||||
minimum_ratio_str=self.MINIMUM_RATIO_STR,
|
||||
maximum_ratio_str=self.MAXIMUM_RATIO_STR,
|
||||
)
|
||||
if input_image is None:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
operation = SynchronousOperation(
|
||||
@ -395,13 +388,7 @@ class FluxKontextProImageNode(ComfyNodeABC):
|
||||
guidance=round(guidance, 1),
|
||||
steps=steps,
|
||||
seed=seed,
|
||||
aspect_ratio=validate_aspect_ratio(
|
||||
aspect_ratio,
|
||||
minimum_ratio=self.MINIMUM_RATIO,
|
||||
maximum_ratio=self.MAXIMUM_RATIO,
|
||||
minimum_ratio_str=self.MINIMUM_RATIO_STR,
|
||||
maximum_ratio_str=self.MAXIMUM_RATIO_STR,
|
||||
),
|
||||
aspect_ratio=aspect_ratio,
|
||||
input_image=(
|
||||
input_image
|
||||
if input_image is None
|
||||
|
@ -324,7 +324,7 @@ class IdeogramV1(ComfyNodeABC):
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Ideogram/v1"
|
||||
CATEGORY = "api node/image/Ideogram"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
@ -483,7 +483,7 @@ class IdeogramV2(ComfyNodeABC):
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Ideogram/v2"
|
||||
CATEGORY = "api node/image/Ideogram"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
@ -649,7 +649,7 @@ class IdeogramV3(ComfyNodeABC):
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Ideogram/v3"
|
||||
CATEGORY = "api node/image/Ideogram"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
|
97
comfy_config/config_parser.py
Normal file
97
comfy_config/config_parser.py
Normal file
@ -0,0 +1,97 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic_settings import PydanticBaseSettingsSource, TomlConfigSettingsSource
|
||||
|
||||
from comfy_config.types import (
|
||||
ComfyConfig,
|
||||
ProjectConfig,
|
||||
PyProjectConfig,
|
||||
PyProjectSettings
|
||||
)
|
||||
|
||||
"""
|
||||
Extract configuration from a custom node directory's pyproject.toml file or a Python file.
|
||||
|
||||
This function reads and parses the pyproject.toml file in the specified directory
|
||||
to extract project and ComfyUI-specific configuration information. If no
|
||||
pyproject.toml file is found, it creates a minimal configuration using the
|
||||
folder name as the project name. If a Python file is provided, it uses the
|
||||
file name (without extension) as the project name.
|
||||
|
||||
Args:
|
||||
path (str): Path to the directory containing the pyproject.toml file, or
|
||||
path to a .py file. If pyproject.toml doesn't exist in a directory,
|
||||
the folder name will be used as the default project name. If a .py
|
||||
file is provided, the filename (without .py extension) will be used
|
||||
as the project name.
|
||||
|
||||
Returns:
|
||||
Optional[PyProjectConfig]: A PyProjectConfig object containing:
|
||||
- project: Basic project information (name, version, dependencies, etc.)
|
||||
- tool_comfy: ComfyUI-specific configuration (publisher_id, models, etc.)
|
||||
Returns None if configuration extraction fails or if the provided file
|
||||
is not a Python file.
|
||||
|
||||
Notes:
|
||||
- If pyproject.toml is missing in a directory, creates a default config with folder name
|
||||
- If a .py file is provided, creates a default config with filename (without extension)
|
||||
- Returns None for non-Python files
|
||||
|
||||
Example:
|
||||
>>> from comfy_config import config_parser
|
||||
>>> # For directory
|
||||
>>> custom_node_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
>>> project_config = config_parser.extract_node_configuration(custom_node_dir)
|
||||
>>> print(project_config.project.name) # "my_custom_node" or name from pyproject.toml
|
||||
>>>
|
||||
>>> # For single-file Python node file
|
||||
>>> py_file_path = os.path.realpath(__file__) # "/path/to/my_node.py"
|
||||
>>> project_config = config_parser.extract_node_configuration(py_file_path)
|
||||
>>> print(project_config.project.name) # "my_node"
|
||||
"""
|
||||
def extract_node_configuration(path) -> Optional[PyProjectConfig]:
|
||||
if os.path.isfile(path):
|
||||
file_path = Path(path)
|
||||
|
||||
if file_path.suffix.lower() != '.py':
|
||||
return None
|
||||
|
||||
project_name = file_path.stem
|
||||
project = ProjectConfig(name=project_name)
|
||||
comfy = ComfyConfig()
|
||||
return PyProjectConfig(project=project, tool_comfy=comfy)
|
||||
|
||||
folder_name = os.path.basename(path)
|
||||
toml_path = Path(path) / "pyproject.toml"
|
||||
|
||||
if not toml_path.exists():
|
||||
project = ProjectConfig(name=folder_name)
|
||||
comfy = ComfyConfig()
|
||||
return PyProjectConfig(project=project, tool_comfy=comfy)
|
||||
|
||||
raw_settings = load_pyproject_settings(toml_path)
|
||||
|
||||
project_data = raw_settings.project
|
||||
|
||||
tool_data = raw_settings.tool
|
||||
comfy_data = tool_data.get("comfy", {}) if tool_data else {}
|
||||
|
||||
return PyProjectConfig(project=project_data, tool_comfy=comfy_data)
|
||||
|
||||
|
||||
def load_pyproject_settings(toml_path: Path) -> PyProjectSettings:
|
||||
class PyProjectLoader(PyProjectSettings):
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls,
|
||||
init_settings: PydanticBaseSettingsSource,
|
||||
env_settings: PydanticBaseSettingsSource,
|
||||
dotenv_settings: PydanticBaseSettingsSource,
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
):
|
||||
return (TomlConfigSettingsSource(settings_cls, toml_path),)
|
||||
|
||||
return PyProjectLoader()
|
93
comfy_config/types.py
Normal file
93
comfy_config/types.py
Normal file
@ -0,0 +1,93 @@
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from typing import List, Optional
|
||||
|
||||
# IMPORTANT: The type definitions specified in pyproject.toml for custom nodes
|
||||
# must remain synchronized with the corresponding files in the https://github.com/Comfy-Org/comfy-cli/blob/main/comfy_cli/registry/types.py.
|
||||
# Any changes to one must be reflected in the other to maintain consistency.
|
||||
|
||||
class NodeVersion(BaseModel):
|
||||
changelog: str
|
||||
dependencies: List[str]
|
||||
deprecated: bool
|
||||
id: str
|
||||
version: str
|
||||
download_url: str
|
||||
|
||||
|
||||
class Node(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
author: Optional[str] = None
|
||||
license: Optional[str] = None
|
||||
icon: Optional[str] = None
|
||||
repository: Optional[str] = None
|
||||
tags: List[str] = Field(default_factory=list)
|
||||
latest_version: Optional[NodeVersion] = None
|
||||
|
||||
|
||||
class PublishNodeVersionResponse(BaseModel):
|
||||
node_version: NodeVersion
|
||||
signedUrl: str
|
||||
|
||||
|
||||
class URLs(BaseModel):
|
||||
homepage: str = Field(default="", alias="Homepage")
|
||||
documentation: str = Field(default="", alias="Documentation")
|
||||
repository: str = Field(default="", alias="Repository")
|
||||
issues: str = Field(default="", alias="Issues")
|
||||
|
||||
|
||||
class Model(BaseModel):
|
||||
location: str
|
||||
model_url: str
|
||||
|
||||
|
||||
class ComfyConfig(BaseModel):
|
||||
publisher_id: str = Field(default="", alias="PublisherId")
|
||||
display_name: str = Field(default="", alias="DisplayName")
|
||||
icon: str = Field(default="", alias="Icon")
|
||||
models: List[Model] = Field(default_factory=list, alias="Models")
|
||||
includes: List[str] = Field(default_factory=list)
|
||||
web: Optional[str] = None
|
||||
|
||||
|
||||
class License(BaseModel):
|
||||
file: str = ""
|
||||
text: str = ""
|
||||
|
||||
|
||||
class ProjectConfig(BaseModel):
|
||||
name: str = ""
|
||||
description: str = ""
|
||||
version: str = "1.0.0"
|
||||
requires_python: str = Field(default=">= 3.9", alias="requires-python")
|
||||
dependencies: List[str] = Field(default_factory=list)
|
||||
license: License = Field(default_factory=License)
|
||||
urls: URLs = Field(default_factory=URLs)
|
||||
|
||||
@field_validator('license', mode='before')
|
||||
@classmethod
|
||||
def validate_license(cls, v):
|
||||
if isinstance(v, str):
|
||||
return License(text=v)
|
||||
elif isinstance(v, dict):
|
||||
return License(**v)
|
||||
elif isinstance(v, License):
|
||||
return v
|
||||
else:
|
||||
return License()
|
||||
|
||||
|
||||
class PyProjectConfig(BaseModel):
|
||||
project: ProjectConfig = Field(default_factory=ProjectConfig)
|
||||
tool_comfy: ComfyConfig = Field(default_factory=ComfyConfig)
|
||||
|
||||
|
||||
class PyProjectSettings(BaseSettings):
|
||||
project: dict = Field(default_factory=dict)
|
||||
|
||||
tool: dict = Field(default_factory=dict)
|
||||
|
||||
model_config = SettingsConfigDict(extra='allow')
|
@ -2,6 +2,7 @@ import nodes
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import comfy.latent_formats
|
||||
|
||||
|
||||
class EmptyCosmosLatentVideo:
|
||||
@ -75,8 +76,53 @@ class CosmosImageToVideoLatent:
|
||||
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
|
||||
return (out_latent,)
|
||||
|
||||
class CosmosPredict2ImageToVideoLatent:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 93, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"start_image": ("IMAGE", ),
|
||||
"end_image": ("IMAGE", ),
|
||||
}}
|
||||
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/inpaint"
|
||||
|
||||
def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None):
|
||||
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
if start_image is None and end_image is None:
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (out_latent,)
|
||||
|
||||
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
|
||||
|
||||
if start_image is not None:
|
||||
latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1)
|
||||
latent[:, :, :latent_temp.shape[-3]] = latent_temp
|
||||
mask[:, :, :latent_temp.shape[-3]] *= 0.0
|
||||
|
||||
if end_image is not None:
|
||||
latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0)
|
||||
latent[:, :, -latent_temp.shape[-3]:] = latent_temp
|
||||
mask[:, :, -latent_temp.shape[-3]:] *= 0.0
|
||||
|
||||
out_latent = {}
|
||||
latent_format = comfy.latent_formats.Wan21()
|
||||
latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
|
||||
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
|
||||
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
|
||||
return (out_latent,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"EmptyCosmosLatentVideo": EmptyCosmosLatentVideo,
|
||||
"CosmosImageToVideoLatent": CosmosImageToVideoLatent,
|
||||
"CosmosPredict2ImageToVideoLatent": CosmosPredict2ImageToVideoLatent,
|
||||
}
|
||||
|
@ -16,7 +16,8 @@ from inspect import cleandoc
|
||||
import torch
|
||||
import comfy.utils
|
||||
|
||||
from comfy.comfy_types import FileLocator
|
||||
from comfy.comfy_types import FileLocator, IO
|
||||
from server import PromptServer
|
||||
|
||||
MAX_RESOLUTION = nodes.MAX_RESOLUTION
|
||||
|
||||
@ -491,6 +492,37 @@ class SaveSVGNode:
|
||||
counter += 1
|
||||
return { "ui": { "images": results } }
|
||||
|
||||
class GetImageSize:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.INT, IO.INT, IO.INT)
|
||||
RETURN_NAMES = ("width", "height", "batch_size")
|
||||
FUNCTION = "get_size"
|
||||
|
||||
CATEGORY = "image"
|
||||
DESCRIPTION = """Returns width and height of the image, and passes it through unchanged."""
|
||||
|
||||
def get_size(self, image, unique_id=None) -> tuple[int, int]:
|
||||
height = image.shape[1]
|
||||
width = image.shape[2]
|
||||
batch_size = image.shape[0]
|
||||
|
||||
# Send progress text to display size on the node
|
||||
if unique_id:
|
||||
PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id)
|
||||
|
||||
return width, height, batch_size
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageCrop": ImageCrop,
|
||||
"RepeatImageBatch": RepeatImageBatch,
|
||||
@ -500,4 +532,5 @@ NODE_CLASS_MAPPINGS = {
|
||||
"SaveAnimatedPNG": SaveAnimatedPNG,
|
||||
"SaveSVGNode": SaveSVGNode,
|
||||
"ImageStitch": ImageStitch,
|
||||
"GetImageSize": GetImageSize,
|
||||
}
|
||||
|
@ -189,7 +189,7 @@ class ModelSamplingContinuousEDM:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps"],),
|
||||
"sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps", "cosmos_rflow"],),
|
||||
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
|
||||
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
|
||||
}}
|
||||
@ -202,6 +202,7 @@ class ModelSamplingContinuousEDM:
|
||||
def patch(self, model, sampling, sigma_max, sigma_min):
|
||||
m = model.clone()
|
||||
|
||||
sampling_base = comfy.model_sampling.ModelSamplingContinuousEDM
|
||||
latent_format = None
|
||||
sigma_data = 1.0
|
||||
if sampling == "eps":
|
||||
@ -215,8 +216,11 @@ class ModelSamplingContinuousEDM:
|
||||
sampling_type = comfy.model_sampling.EDM
|
||||
sigma_data = 0.5
|
||||
latent_format = comfy.latent_formats.SDXL_Playground_2_5()
|
||||
elif sampling == "cosmos_rflow":
|
||||
sampling_type = comfy.model_sampling.COSMOS_RFLOW
|
||||
sampling_base = comfy.model_sampling.ModelSamplingCosmosRFlow
|
||||
|
||||
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
|
||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||
pass
|
||||
|
||||
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
||||
|
709
comfy_extras/nodes_train.py
Normal file
709
comfy_extras/nodes_train.py
Normal file
@ -0,0 +1,709 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import safetensors
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
import torch.utils.checkpoint
|
||||
import tqdm
|
||||
|
||||
import comfy.samplers
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy_extras.nodes_custom_sampler
|
||||
import folder_paths
|
||||
import node_helpers
|
||||
from comfy.cli_args import args
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
from comfy.weight_adapter import adapters
|
||||
|
||||
|
||||
class TrainSampler(comfy.samplers.Sampler):
|
||||
|
||||
def __init__(self, loss_fn, optimizer, loss_callback=None):
|
||||
self.loss_fn = loss_fn
|
||||
self.optimizer = optimizer
|
||||
self.loss_callback = loss_callback
|
||||
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
self.optimizer.zero_grad()
|
||||
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False)
|
||||
latent = model_wrap.inner_model.model_sampling.noise_scaling(
|
||||
torch.zeros_like(sigmas),
|
||||
torch.zeros_like(noise, requires_grad=True),
|
||||
latent_image,
|
||||
False
|
||||
)
|
||||
|
||||
# Ensure model is in training mode and computing gradients
|
||||
# x0 pred
|
||||
denoised = model_wrap(noise, sigmas, **extra_args)
|
||||
try:
|
||||
loss = self.loss_fn(denoised, latent.clone())
|
||||
except RuntimeError as e:
|
||||
if "does not require grad and does not have a grad_fn" in str(e):
|
||||
logging.info("WARNING: This is likely due to the model is loaded in inference mode.")
|
||||
loss.backward()
|
||||
if self.loss_callback:
|
||||
self.loss_callback(loss.item())
|
||||
|
||||
self.optimizer.step()
|
||||
# torch.cuda.memory._dump_snapshot("trainn.pickle")
|
||||
# torch.cuda.memory._record_memory_history(enabled=None)
|
||||
return torch.zeros_like(latent_image)
|
||||
|
||||
|
||||
class BiasDiff(torch.nn.Module):
|
||||
def __init__(self, bias):
|
||||
super().__init__()
|
||||
self.bias = bias
|
||||
|
||||
def __call__(self, b):
|
||||
org_dtype = b.dtype
|
||||
return (b.to(self.bias) + self.bias).to(org_dtype)
|
||||
|
||||
def passive_memory_usage(self):
|
||||
return self.bias.nelement() * self.bias.element_size()
|
||||
|
||||
def move_to(self, device):
|
||||
self.to(device=device)
|
||||
return self.passive_memory_usage()
|
||||
|
||||
|
||||
def load_and_process_images(image_files, input_dir, resize_method="None"):
|
||||
"""Utility function to load and process a list of images.
|
||||
|
||||
Args:
|
||||
image_files: List of image filenames
|
||||
input_dir: Base directory containing the images
|
||||
resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Batch of processed images
|
||||
"""
|
||||
if not image_files:
|
||||
raise ValueError("No valid images found in input")
|
||||
|
||||
output_images = []
|
||||
w, h = None, None
|
||||
|
||||
for file in image_files:
|
||||
image_path = os.path.join(input_dir, file)
|
||||
img = node_helpers.pillow(Image.open, image_path)
|
||||
|
||||
if img.mode == "I":
|
||||
img = img.point(lambda i: i * (1 / 255))
|
||||
img = img.convert("RGB")
|
||||
|
||||
if w is None and h is None:
|
||||
w, h = img.size[0], img.size[1]
|
||||
|
||||
# Resize image to first image
|
||||
if img.size[0] != w or img.size[1] != h:
|
||||
if resize_method == "Stretch":
|
||||
img = img.resize((w, h), Image.Resampling.LANCZOS)
|
||||
elif resize_method == "Crop":
|
||||
img = img.crop((0, 0, w, h))
|
||||
elif resize_method == "Pad":
|
||||
img = img.resize((w, h), Image.Resampling.LANCZOS)
|
||||
elif resize_method == "None":
|
||||
raise ValueError(
|
||||
"Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
|
||||
)
|
||||
|
||||
img_array = np.array(img).astype(np.float32) / 255.0
|
||||
img_tensor = torch.from_numpy(img_array)[None,]
|
||||
output_images.append(img_tensor)
|
||||
|
||||
return torch.cat(output_images, dim=0)
|
||||
|
||||
|
||||
class LoadImageSetNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"images": (
|
||||
[
|
||||
f
|
||||
for f in os.listdir(folder_paths.get_input_directory())
|
||||
if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"))
|
||||
],
|
||||
{"image_upload": True, "allow_batch": True},
|
||||
)
|
||||
},
|
||||
"optional": {
|
||||
"resize_method": (
|
||||
["None", "Stretch", "Crop", "Pad"],
|
||||
{"default": "None"},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
INPUT_IS_LIST = True
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "load_images"
|
||||
CATEGORY = "loaders"
|
||||
EXPERIMENTAL = True
|
||||
DESCRIPTION = "Loads a batch of images from a directory for training."
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(s, images, resize_method):
|
||||
filenames = images[0] if isinstance(images[0], list) else images
|
||||
|
||||
for image in filenames:
|
||||
if not folder_paths.exists_annotated_filepath(image):
|
||||
return "Invalid image file: {}".format(image)
|
||||
return True
|
||||
|
||||
def load_images(self, input_files, resize_method):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"]
|
||||
image_files = [
|
||||
f
|
||||
for f in input_files
|
||||
if any(f.lower().endswith(ext) for ext in valid_extensions)
|
||||
]
|
||||
output_tensor = load_and_process_images(image_files, input_dir, resize_method)
|
||||
return (output_tensor,)
|
||||
|
||||
|
||||
class LoadImageSetFromFolderNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."})
|
||||
},
|
||||
"optional": {
|
||||
"resize_method": (
|
||||
["None", "Stretch", "Crop", "Pad"],
|
||||
{"default": "None"},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "load_images"
|
||||
CATEGORY = "loaders"
|
||||
EXPERIMENTAL = True
|
||||
DESCRIPTION = "Loads a batch of images from a directory for training."
|
||||
|
||||
def load_images(self, folder, resize_method):
|
||||
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
|
||||
valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
|
||||
image_files = [
|
||||
f
|
||||
for f in os.listdir(sub_input_dir)
|
||||
if any(f.lower().endswith(ext) for ext in valid_extensions)
|
||||
]
|
||||
output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method)
|
||||
return (output_tensor,)
|
||||
|
||||
|
||||
def draw_loss_graph(loss_map, steps):
|
||||
width, height = 500, 300
|
||||
img = Image.new("RGB", (width, height), "white")
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
|
||||
scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_map.values()]
|
||||
|
||||
prev_point = (0, height - int(scaled_loss[0] * height))
|
||||
for i, l in enumerate(scaled_loss[1:], start=1):
|
||||
x = int(i / (steps - 1) * width)
|
||||
y = height - int(l * height)
|
||||
draw.line([prev_point, (x, y)], fill="blue", width=2)
|
||||
prev_point = (x, y)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None):
|
||||
if result is None:
|
||||
result = []
|
||||
elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
|
||||
result.append(model)
|
||||
logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
|
||||
return result
|
||||
name = name or "root"
|
||||
for next_name, child in model.named_children():
|
||||
find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}")
|
||||
return result
|
||||
|
||||
|
||||
def patch(m):
|
||||
if not hasattr(m, "forward"):
|
||||
return
|
||||
org_forward = m.forward
|
||||
def fwd(args, kwargs):
|
||||
return org_forward(*args, **kwargs)
|
||||
def checkpointing_fwd(*args, **kwargs):
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
fwd, args, kwargs, use_reentrant=False
|
||||
)
|
||||
m.org_forward = org_forward
|
||||
m.forward = checkpointing_fwd
|
||||
|
||||
|
||||
def unpatch(m):
|
||||
if hasattr(m, "org_forward"):
|
||||
m.forward = m.org_forward
|
||||
del m.org_forward
|
||||
|
||||
|
||||
class TrainLoraNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}),
|
||||
"latents": (
|
||||
"LATENT",
|
||||
{
|
||||
"tooltip": "The Latents to use for training, serve as dataset/input of the model."
|
||||
},
|
||||
),
|
||||
"positive": (
|
||||
IO.CONDITIONING,
|
||||
{"tooltip": "The positive conditioning to use for training."},
|
||||
),
|
||||
"batch_size": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
"max": 10000,
|
||||
"step": 1,
|
||||
"tooltip": "The batch size to use for training.",
|
||||
},
|
||||
),
|
||||
"steps": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 16,
|
||||
"min": 1,
|
||||
"max": 100000,
|
||||
"tooltip": "The number of steps to train the LoRA for.",
|
||||
},
|
||||
),
|
||||
"learning_rate": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.0005,
|
||||
"min": 0.0000001,
|
||||
"max": 1.0,
|
||||
"step": 0.000001,
|
||||
"tooltip": "The learning rate to use for training.",
|
||||
},
|
||||
),
|
||||
"rank": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 8,
|
||||
"min": 1,
|
||||
"max": 128,
|
||||
"tooltip": "The rank of the LoRA layers.",
|
||||
},
|
||||
),
|
||||
"optimizer": (
|
||||
["AdamW", "Adam", "SGD", "RMSprop"],
|
||||
{
|
||||
"default": "AdamW",
|
||||
"tooltip": "The optimizer to use for training.",
|
||||
},
|
||||
),
|
||||
"loss_function": (
|
||||
["MSE", "L1", "Huber", "SmoothL1"],
|
||||
{
|
||||
"default": "MSE",
|
||||
"tooltip": "The loss function to use for training.",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)",
|
||||
},
|
||||
),
|
||||
"training_dtype": (
|
||||
["bf16", "fp32"],
|
||||
{"default": "bf16", "tooltip": "The dtype to use for training."},
|
||||
),
|
||||
"lora_dtype": (
|
||||
["bf16", "fp32"],
|
||||
{"default": "bf16", "tooltip": "The dtype to use for lora."},
|
||||
),
|
||||
"existing_lora": (
|
||||
folder_paths.get_filename_list("loras") + ["[None]"],
|
||||
{
|
||||
"default": "[None]",
|
||||
"tooltip": "The existing LoRA to append to. Set to None for new LoRA.",
|
||||
},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT)
|
||||
RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps")
|
||||
FUNCTION = "train"
|
||||
CATEGORY = "training"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def train(
|
||||
self,
|
||||
model,
|
||||
latents,
|
||||
positive,
|
||||
batch_size,
|
||||
steps,
|
||||
learning_rate,
|
||||
rank,
|
||||
optimizer,
|
||||
loss_function,
|
||||
seed,
|
||||
training_dtype,
|
||||
lora_dtype,
|
||||
existing_lora,
|
||||
):
|
||||
mp = model.clone()
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
|
||||
latents = latents["samples"].to(dtype)
|
||||
num_images = latents.shape[0]
|
||||
|
||||
with torch.inference_mode(False):
|
||||
lora_sd = {}
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(seed)
|
||||
|
||||
# Load existing LoRA weights if provided
|
||||
existing_weights = {}
|
||||
existing_steps = 0
|
||||
if existing_lora != "[None]":
|
||||
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
|
||||
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
|
||||
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
|
||||
if lora_path:
|
||||
existing_weights = comfy.utils.load_torch_file(lora_path)
|
||||
|
||||
all_weight_adapters = []
|
||||
for n, m in mp.model.named_modules():
|
||||
if hasattr(m, "weight_function"):
|
||||
if m.weight is not None:
|
||||
key = "{}.weight".format(n)
|
||||
shape = m.weight.shape
|
||||
if len(shape) >= 2:
|
||||
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
|
||||
dora_scale = existing_weights.get(
|
||||
f"{key}.dora_scale", None
|
||||
)
|
||||
for adapter_cls in adapters:
|
||||
existing_adapter = adapter_cls.load(
|
||||
n, existing_weights, alpha, dora_scale
|
||||
)
|
||||
if existing_adapter is not None:
|
||||
break
|
||||
else:
|
||||
# If no existing adapter found, use LoRA
|
||||
# We will add algo option in the future
|
||||
existing_adapter = None
|
||||
adapter_cls = adapters[0]
|
||||
|
||||
if existing_adapter is not None:
|
||||
train_adapter = existing_adapter.to_train().to(lora_dtype)
|
||||
else:
|
||||
# Use LoRA with alpha=1.0 by default
|
||||
train_adapter = adapter_cls.create_train(
|
||||
m.weight, rank=rank, alpha=1.0
|
||||
).to(lora_dtype)
|
||||
for name, parameter in train_adapter.named_parameters():
|
||||
lora_sd[f"{n}.{name}"] = parameter
|
||||
|
||||
mp.add_weight_wrapper(key, train_adapter)
|
||||
all_weight_adapters.append(train_adapter)
|
||||
else:
|
||||
diff = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
m.weight.shape, dtype=lora_dtype, requires_grad=True
|
||||
)
|
||||
)
|
||||
diff_module = BiasDiff(diff)
|
||||
mp.add_weight_wrapper(key, BiasDiff(diff))
|
||||
all_weight_adapters.append(diff_module)
|
||||
lora_sd["{}.diff".format(n)] = diff
|
||||
if hasattr(m, "bias") and m.bias is not None:
|
||||
key = "{}.bias".format(n)
|
||||
bias = torch.nn.Parameter(
|
||||
torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True)
|
||||
)
|
||||
bias_module = BiasDiff(bias)
|
||||
lora_sd["{}.diff_b".format(n)] = bias
|
||||
mp.add_weight_wrapper(key, BiasDiff(bias))
|
||||
all_weight_adapters.append(bias_module)
|
||||
|
||||
if optimizer == "Adam":
|
||||
optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
|
||||
elif optimizer == "AdamW":
|
||||
optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
|
||||
elif optimizer == "SGD":
|
||||
optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
|
||||
elif optimizer == "RMSprop":
|
||||
optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
|
||||
|
||||
# Setup loss function based on selection
|
||||
if loss_function == "MSE":
|
||||
criterion = torch.nn.MSELoss()
|
||||
elif loss_function == "L1":
|
||||
criterion = torch.nn.L1Loss()
|
||||
elif loss_function == "Huber":
|
||||
criterion = torch.nn.HuberLoss()
|
||||
elif loss_function == "SmoothL1":
|
||||
criterion = torch.nn.SmoothL1Loss()
|
||||
|
||||
# setup models
|
||||
for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
|
||||
patch(m)
|
||||
comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
|
||||
|
||||
# Setup sampler and guider like in test script
|
||||
loss_map = {"loss": []}
|
||||
def loss_callback(loss):
|
||||
loss_map["loss"].append(loss)
|
||||
pbar.set_postfix({"loss": f"{loss:.4f}"})
|
||||
train_sampler = TrainSampler(
|
||||
criterion, optimizer, loss_callback=loss_callback
|
||||
)
|
||||
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
|
||||
guider.set_conds(positive) # Set conditioning from input
|
||||
ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced()
|
||||
|
||||
# yoland: this currently resize to the first image in the dataset
|
||||
|
||||
# Training loop
|
||||
torch.cuda.empty_cache()
|
||||
try:
|
||||
for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
|
||||
# Generate random sigma
|
||||
sigma = mp.model.model_sampling.percent_to_sigma(
|
||||
torch.rand((1,)).item()
|
||||
)
|
||||
sigma = torch.tensor([sigma])
|
||||
|
||||
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)
|
||||
|
||||
indices = torch.randperm(num_images)[:batch_size]
|
||||
ss.sample(
|
||||
noise, guider, train_sampler, sigma, {"samples": latents[indices].clone()}
|
||||
)
|
||||
finally:
|
||||
for m in mp.model.modules():
|
||||
unpatch(m)
|
||||
del ss, train_sampler, optimizer
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
for adapter in all_weight_adapters:
|
||||
adapter.requires_grad_(False)
|
||||
|
||||
for param in lora_sd:
|
||||
lora_sd[param] = lora_sd[param].to(lora_dtype)
|
||||
|
||||
return (mp, lora_sd, loss_map, steps + existing_steps)
|
||||
|
||||
|
||||
class LoraModelLoader:
|
||||
def __init__(self):
|
||||
self.loaded_lora = None
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
|
||||
"lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to apply to the diffusion model."}),
|
||||
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
OUTPUT_TOOLTIPS = ("The modified diffusion model.",)
|
||||
FUNCTION = "load_lora_model"
|
||||
|
||||
CATEGORY = "loaders"
|
||||
DESCRIPTION = "Load Trained LoRA weights from Train LoRA node."
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def load_lora_model(self, model, lora, strength_model):
|
||||
if strength_model == 0:
|
||||
return (model, )
|
||||
|
||||
model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
|
||||
return (model_lora, )
|
||||
|
||||
|
||||
class SaveLoRA:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"lora": (
|
||||
IO.LORA_MODEL,
|
||||
{
|
||||
"tooltip": "The LoRA model to save. Do not use the model with LoRA layers."
|
||||
},
|
||||
),
|
||||
"prefix": (
|
||||
"STRING",
|
||||
{
|
||||
"default": "loras/ComfyUI_trained_lora",
|
||||
"tooltip": "The prefix to use for the saved LoRA file.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"steps": (
|
||||
IO.INT,
|
||||
{
|
||||
"forceInput": True,
|
||||
"tooltip": "Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
|
||||
},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save"
|
||||
CATEGORY = "loaders"
|
||||
EXPERIMENTAL = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def save(self, lora, prefix, steps=None):
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(prefix, self.output_dir)
|
||||
if steps is None:
|
||||
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||
else:
|
||||
output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
|
||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||
safetensors.torch.save_file(lora, output_checkpoint)
|
||||
return {}
|
||||
|
||||
|
||||
class LossGraphNode:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_temp_directory()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"loss": (IO.LOSS_MAP, {"default": {}}),
|
||||
"filename_prefix": (IO.STRING, {"default": "loss_graph"}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "plot_loss"
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = "training"
|
||||
EXPERIMENTAL = True
|
||||
DESCRIPTION = "Plots the loss graph and saves it to the output directory."
|
||||
|
||||
def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None):
|
||||
loss_values = loss["loss"]
|
||||
width, height = 800, 480
|
||||
margin = 40
|
||||
|
||||
img = Image.new(
|
||||
"RGB", (width + margin, height + margin), "white"
|
||||
) # Extend canvas
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
min_loss, max_loss = min(loss_values), max(loss_values)
|
||||
scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_values]
|
||||
|
||||
steps = len(loss_values)
|
||||
|
||||
prev_point = (margin, height - int(scaled_loss[0] * height))
|
||||
for i, l in enumerate(scaled_loss[1:], start=1):
|
||||
x = margin + int(i / steps * width) # Scale X properly
|
||||
y = height - int(l * height)
|
||||
draw.line([prev_point, (x, y)], fill="blue", width=2)
|
||||
prev_point = (x, y)
|
||||
|
||||
draw.line([(margin, 0), (margin, height)], fill="black", width=2) # Y-axis
|
||||
draw.line(
|
||||
[(margin, height), (width + margin, height)], fill="black", width=2
|
||||
) # X-axis
|
||||
|
||||
font = None
|
||||
try:
|
||||
font = ImageFont.truetype("arial.ttf", 12)
|
||||
except IOError:
|
||||
font = ImageFont.load_default()
|
||||
|
||||
# Add axis labels
|
||||
draw.text((5, height // 2), "Loss", font=font, fill="black")
|
||||
draw.text((width // 2, height + 10), "Steps", font=font, fill="black")
|
||||
|
||||
# Add min/max loss values
|
||||
draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
|
||||
draw.text(
|
||||
(margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
|
||||
)
|
||||
|
||||
metadata = None
|
||||
if not args.disable_metadata:
|
||||
metadata = PngInfo()
|
||||
if prompt is not None:
|
||||
metadata.add_text("prompt", json.dumps(prompt))
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
|
||||
|
||||
date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
img.save(
|
||||
os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"),
|
||||
pnginfo=metadata,
|
||||
)
|
||||
return {
|
||||
"ui": {
|
||||
"images": [
|
||||
{
|
||||
"filename": f"{filename_prefix}_{date}.png",
|
||||
"subfolder": "",
|
||||
"type": "temp",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TrainLoraNode": TrainLoraNode,
|
||||
"SaveLoRANode": SaveLoRA,
|
||||
"LoraModelLoader": LoraModelLoader,
|
||||
"LoadImageSetFromFolderNode": LoadImageSetFromFolderNode,
|
||||
"LossGraphNode": LossGraphNode,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TrainLoraNode": "Train LoRA",
|
||||
"SaveLoRANode": "Save LoRA Weights",
|
||||
"LoraModelLoader": "Load LoRA Model",
|
||||
"LoadImageSetFromFolderNode": "Load Image Dataset from Folder",
|
||||
"LossGraphNode": "Plot Loss Graph",
|
||||
}
|
@ -23,6 +23,10 @@ class WebcamCapture(nodes.LoadImage):
|
||||
def load_capture(self, image, **kwargs):
|
||||
return super().load_image(folder_paths.get_annotated_filepath(image))
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(cls, image, width, height, capture_on_queue):
|
||||
return super().IS_CHANGED(image)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"WebcamCapture": WebcamCapture,
|
||||
|
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.3.39"
|
||||
__version__ = "0.3.41"
|
||||
|
28
execution.py
28
execution.py
@ -1,23 +1,35 @@
|
||||
import sys
|
||||
import copy
|
||||
import logging
|
||||
import threading
|
||||
import heapq
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from enum import Enum
|
||||
import inspect
|
||||
from typing import List, Literal, NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
import nodes
|
||||
|
||||
import comfy.model_management
|
||||
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
||||
from comfy_execution.graph_utils import is_link, GraphBuilder
|
||||
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
|
||||
import nodes
|
||||
from comfy_execution.caching import (
|
||||
CacheKeySetID,
|
||||
CacheKeySetInputSignature,
|
||||
DependencyAwareCache,
|
||||
HierarchicalCache,
|
||||
LRUCache,
|
||||
)
|
||||
from comfy_execution.graph import (
|
||||
DynamicPrompt,
|
||||
ExecutionBlocker,
|
||||
ExecutionList,
|
||||
get_input_info,
|
||||
)
|
||||
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||
from comfy_execution.validation import validate_node_input
|
||||
|
||||
|
||||
class ExecutionResult(Enum):
|
||||
SUCCESS = 0
|
||||
FAILURE = 1
|
||||
|
@ -276,6 +276,9 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
|
||||
|
||||
|
||||
def get_full_path(folder_name: str, filename: str) -> str | None:
|
||||
"""
|
||||
Get the full path of a file in a folder, has to be a file
|
||||
"""
|
||||
global folder_names_and_paths
|
||||
folder_name = map_legacy(folder_name)
|
||||
if folder_name not in folder_names_and_paths:
|
||||
@ -293,6 +296,9 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
|
||||
|
||||
|
||||
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
|
||||
"""
|
||||
Get the full path of a file in a folder, has to be a file
|
||||
"""
|
||||
full_path = get_full_path(folder_name, filename)
|
||||
if full_path is None:
|
||||
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
|
||||
@ -394,3 +400,26 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
|
||||
os.makedirs(full_output_folder, exist_ok=True)
|
||||
counter = 1
|
||||
return full_output_folder, filename, counter, subfolder, filename_prefix
|
||||
|
||||
def get_input_subfolders() -> list[str]:
|
||||
"""Returns a list of all subfolder paths in the input directory, recursively.
|
||||
|
||||
Returns:
|
||||
List of folder paths relative to the input directory, excluding the root directory
|
||||
"""
|
||||
input_dir = get_input_directory()
|
||||
folders = []
|
||||
|
||||
try:
|
||||
if not os.path.exists(input_dir):
|
||||
return []
|
||||
|
||||
for root, dirs, _ in os.walk(input_dir):
|
||||
rel_path = os.path.relpath(root, input_dir)
|
||||
if rel_path != ".": # Only include non-root directories
|
||||
# Normalize path separators to forward slashes
|
||||
folders.append(rel_path.replace(os.sep, '/'))
|
||||
|
||||
return sorted(folders)
|
||||
except FileNotFoundError:
|
||||
return []
|
||||
|
14
main.py
14
main.py
@ -17,7 +17,6 @@ if __name__ == "__main__":
|
||||
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
||||
os.environ['DO_NOT_TRACK'] = '1'
|
||||
|
||||
|
||||
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
||||
|
||||
def apply_custom_paths():
|
||||
@ -238,6 +237,15 @@ def cleanup_temp():
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def setup_database():
|
||||
try:
|
||||
from app.database.db import init_db, dependencies_available
|
||||
if dependencies_available():
|
||||
init_db()
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
|
||||
|
||||
|
||||
def start_comfyui(asyncio_loop=None):
|
||||
"""
|
||||
Starts the ComfyUI server using the provided asyncio event loop or creates a new one.
|
||||
@ -266,6 +274,7 @@ def start_comfyui(asyncio_loop=None):
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
cuda_malloc_warning()
|
||||
setup_database()
|
||||
|
||||
prompt_server.add_routes()
|
||||
hijack_progress(prompt_server)
|
||||
@ -300,6 +309,9 @@ if __name__ == "__main__":
|
||||
logging.info("Python version: {}".format(sys.version))
|
||||
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
|
||||
|
||||
if sys.version_info.major == 3 and sys.version_info.minor < 10:
|
||||
logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")
|
||||
|
||||
event_loop, _, start_all_func = start_comfyui()
|
||||
try:
|
||||
x = start_all_func()
|
||||
|
21
nodes.py
21
nodes.py
@ -2067,6 +2067,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ImageQuantize": "Image Quantize",
|
||||
"ImageSharpen": "Image Sharpen",
|
||||
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
|
||||
"GetImageSize": "Get Image Size",
|
||||
# _for_testing
|
||||
"VAEDecodeTiled": "VAE Decode (Tiled)",
|
||||
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
||||
@ -2124,6 +2125,25 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
|
||||
|
||||
LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir)
|
||||
|
||||
try:
|
||||
from comfy_config import config_parser
|
||||
|
||||
project_config = config_parser.extract_node_configuration(module_path)
|
||||
|
||||
web_dir_name = project_config.tool_comfy.web
|
||||
|
||||
if web_dir_name:
|
||||
web_dir_path = os.path.join(module_path, web_dir_name)
|
||||
|
||||
if os.path.isdir(web_dir_path):
|
||||
project_name = project_config.project.name
|
||||
|
||||
EXTENSION_WEB_DIRS[project_name] = web_dir_path
|
||||
|
||||
logging.info("Automatically register web folder {} for {}".format(web_dir_name, project_name))
|
||||
except Exception as e:
|
||||
logging.warning(f"Unable to parse pyproject.toml due to lack dependency pydantic-settings, please run 'pip install -r requirements.txt': {e}")
|
||||
|
||||
if hasattr(module, "WEB_DIRECTORY") and getattr(module, "WEB_DIRECTORY") is not None:
|
||||
web_dir = os.path.abspath(os.path.join(module_dir, getattr(module, "WEB_DIRECTORY")))
|
||||
if os.path.isdir(web_dir):
|
||||
@ -2211,6 +2231,7 @@ def init_builtin_extra_nodes():
|
||||
"nodes_model_downscale.py",
|
||||
"nodes_images.py",
|
||||
"nodes_video_model.py",
|
||||
"nodes_train.py",
|
||||
"nodes_sag.py",
|
||||
"nodes_perpneg.py",
|
||||
"nodes_stable3d.py",
|
||||
|
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.3.39"
|
||||
version = "0.3.41"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
|
@ -1,6 +1,6 @@
|
||||
comfyui-frontend-package==1.21.3
|
||||
comfyui-workflow-templates==0.1.23
|
||||
comfyui-embedded-docs==0.2.0
|
||||
comfyui-frontend-package==1.21.7
|
||||
comfyui-workflow-templates==0.1.28
|
||||
comfyui-embedded-docs==0.2.2
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
@ -18,6 +18,8 @@ Pillow
|
||||
scipy
|
||||
tqdm
|
||||
psutil
|
||||
alembic
|
||||
SQLAlchemy
|
||||
|
||||
#non essential dependencies:
|
||||
kornia>=0.7.1
|
||||
@ -25,3 +27,4 @@ spandrel
|
||||
soundfile
|
||||
av>=14.2.0
|
||||
pydantic~=2.0
|
||||
pydantic-settings~=2.0
|
||||
|
@ -390,7 +390,7 @@ class PromptServer():
|
||||
async def view_image(request):
|
||||
if "filename" in request.rel_url.query:
|
||||
filename = request.rel_url.query["filename"]
|
||||
filename,output_dir = folder_paths.annotated_filepath(filename)
|
||||
filename, output_dir = folder_paths.annotated_filepath(filename)
|
||||
|
||||
if not filename:
|
||||
return web.Response(status=400)
|
||||
@ -476,9 +476,8 @@ class PromptServer():
|
||||
# Get content type from mimetype, defaulting to 'application/octet-stream'
|
||||
content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
|
||||
|
||||
# For security, force certain extensions to download instead of display
|
||||
file_extension = os.path.splitext(filename)[1].lower()
|
||||
if file_extension in {'.html', '.htm', '.js', '.css'}:
|
||||
# For security, force certain mimetypes to download instead of display
|
||||
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
|
||||
content_type = 'application/octet-stream' # Forces download
|
||||
|
||||
return web.FileResponse(
|
||||
@ -789,7 +788,7 @@ class PromptServer():
|
||||
if hasattr(Image, 'Resampling'):
|
||||
resampling = Image.Resampling.BILINEAR
|
||||
else:
|
||||
resampling = Image.ANTIALIAS
|
||||
resampling = Image.Resampling.LANCZOS
|
||||
|
||||
image = ImageOps.contain(image, (max_size, max_size), resampling)
|
||||
type_num = 1
|
||||
|
@ -5,7 +5,10 @@ from unittest.mock import patch, MagicMock
|
||||
mock_nodes = MagicMock()
|
||||
mock_nodes.MAX_RESOLUTION = 16384
|
||||
|
||||
with patch.dict('sys.modules', {'nodes': mock_nodes}):
|
||||
# Mock server module for PromptServer
|
||||
mock_server = MagicMock()
|
||||
|
||||
with patch.dict('sys.modules', {'nodes': mock_nodes, 'server': mock_server}):
|
||||
from comfy_extras.nodes_images import ImageStitch
|
||||
|
||||
|
||||
|
51
tests-unit/folder_paths_test/misc_test.py
Normal file
51
tests-unit/folder_paths_test/misc_test.py
Normal file
@ -0,0 +1,51 @@
|
||||
import pytest
|
||||
import os
|
||||
import tempfile
|
||||
from folder_paths import get_input_subfolders, set_input_directory
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mock_folder_structure():
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Create a nested folder structure
|
||||
folders = [
|
||||
"folder1",
|
||||
"folder1/subfolder1",
|
||||
"folder1/subfolder2",
|
||||
"folder2",
|
||||
"folder2/deep",
|
||||
"folder2/deep/nested",
|
||||
"empty_folder"
|
||||
]
|
||||
|
||||
# Create the folders
|
||||
for folder in folders:
|
||||
os.makedirs(os.path.join(temp_dir, folder))
|
||||
|
||||
# Add some files to test they're not included
|
||||
with open(os.path.join(temp_dir, "root_file.txt"), "w") as f:
|
||||
f.write("test")
|
||||
with open(os.path.join(temp_dir, "folder1", "test.txt"), "w") as f:
|
||||
f.write("test")
|
||||
|
||||
set_input_directory(temp_dir)
|
||||
yield temp_dir
|
||||
|
||||
|
||||
def test_gets_all_folders(mock_folder_structure):
|
||||
folders = get_input_subfolders()
|
||||
expected = ["folder1", "folder1/subfolder1", "folder1/subfolder2",
|
||||
"folder2", "folder2/deep", "folder2/deep/nested", "empty_folder"]
|
||||
assert sorted(folders) == sorted(expected)
|
||||
|
||||
|
||||
def test_handles_nonexistent_input_directory():
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
nonexistent = os.path.join(temp_dir, "nonexistent")
|
||||
set_input_directory(nonexistent)
|
||||
assert get_input_subfolders() == []
|
||||
|
||||
|
||||
def test_empty_input_directory():
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
set_input_directory(temp_dir)
|
||||
assert get_input_subfolders() == [] # Empty since we don't include root
|
18
utils/install_util.py
Normal file
18
utils/install_util.py
Normal file
@ -0,0 +1,18 @@
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
# The path to the requirements.txt file
|
||||
requirements_path = Path(__file__).parents[1] / "requirements.txt"
|
||||
|
||||
|
||||
def get_missing_requirements_message():
|
||||
"""The warning message to display when a package is missing."""
|
||||
|
||||
extra = ""
|
||||
if sys.flags.no_user_site:
|
||||
extra = "-s "
|
||||
return f"""
|
||||
Please install the updated requirements.txt file by running:
|
||||
{sys.executable} {extra}-m pip install -r {requirements_path}
|
||||
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem.
|
||||
""".strip()
|
Loading…
x
Reference in New Issue
Block a user