mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-28 00:36:32 +00:00

Merge branch 'master' into worksplit-multigpu
This commit is contained in commit 1ae98932f1.
8  .github/ISSUE_TEMPLATE/bug-report.yml  (vendored)
@@ -15,6 +15,14 @@ body:
       steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.

       If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
+  - type: checkboxes
+    id: custom-nodes-test
+    attributes:
+      label: Custom Node Testing
+      description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
+      options:
+        - label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
+          required: true
   - type: textarea
     attributes:
       label: Expected Behavior

8  .github/ISSUE_TEMPLATE/user-support.yml  (vendored)
@@ -11,6 +11,14 @@ body:
       **2:** You have made an effort to find public answers to your question before asking here. In other words, you googled it first, and scrolled through recent help topics.

       If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
+  - type: checkboxes
+    id: custom-nodes-test
+    attributes:
+      label: Custom Node Testing
+      description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
+      options:
+        - label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
+          required: true
   - type: textarea
     attributes:
       label: Your question
@@ -6,6 +6,7 @@

 [![Website][website-shield]][website-url]
 [![Dynamic JSON Badge][discord-shield]][discord-url]
+[![Twitter][twitter-shield]][twitter-url]
 [![Matrix][matrix-shield]][matrix-url]
 <br>
 [![][github-release-shield]][github-release-link]

@@ -20,6 +21,8 @@

 <!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
 [discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
 [discord-url]: https://www.comfy.org/discord
+[twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI
+[twitter-url]: https://x.com/ComfyUI

 [github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver
 [github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases
@@ -62,12 +65,13 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
   - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
   - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
   - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
+  - [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
 - Video Models
   - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
   - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
   - [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
   - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
-  - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
+  - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
   - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
 - Audio Models
   - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)

@@ -95,7 +99,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
 - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
 - Latent previews with [TAESD](#how-to-show-high-quality-previews)
 - Starts up very fast.
-- Works fully offline: will never download anything.
+- Works fully offline: core will never download anything unless you want to.
+- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
 - [Config file](extra_model_paths.yaml.example) to set the search paths for models.

 Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)

84  alembic.ini  (new file)
@@ -0,0 +1,84 @@
# A generic, single database configuration.

[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = alembic_db

# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s

# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .

# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =

# max length of characters to apply to the "slug" field
# truncate_slug_length = 40

# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false

# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false

# version location specification; This defaults
# to alembic_db/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions

# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
# version_path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
version_path_separator = os

# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false

# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8

sqlalchemy.url = sqlite:///user/comfyui.db


[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples

# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME

# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME

4  alembic_db/README.md  (new file)
@@ -0,0 +1,4 @@
## Generate new revision

1. Update models in `/app/database/models.py`
2. Run `alembic revision --autogenerate -m "{your message}"` (or use the equivalent Python API, sketched below)
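The same step can be driven from Python instead of the CLI; a minimal sketch using Alembic's command API (the revision message is illustrative, and the working directory is assumed to be the repository root where `alembic.ini` lives):

```python
# Sketch only: programmatic equivalent of
#   alembic revision --autogenerate -m "{your message}"
from alembic import command
from alembic.config import Config

config = Config("alembic.ini")  # picks up script_location = alembic_db
command.revision(config, message="describe your schema change", autogenerate=True)
```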

64  alembic_db/env.py  (new file)
@@ -0,0 +1,64 @@
from sqlalchemy import engine_from_config
from sqlalchemy import pool

from alembic import context

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config


from app.database.models import Base
target_metadata = Base.metadata

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.
    """
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.
    """
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection, target_metadata=target_metadata
        )

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()

28  alembic_db/script.py.mako  (new file)
@@ -0,0 +1,28 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    """Upgrade schema."""
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    """Downgrade schema."""
    ${downgrades if downgrades else "pass"}

112  app/database/db.py  (new file)
@@ -0,0 +1,112 @@
import logging
import os
import shutil
from app.logger import log_startup_warning
from utils.install_util import get_missing_requirements_message
from comfy.cli_args import args

_DB_AVAILABLE = False
Session = None


try:
    from alembic import command
    from alembic.config import Config
    from alembic.runtime.migration import MigrationContext
    from alembic.script import ScriptDirectory
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    _DB_AVAILABLE = True
except ImportError as e:
    log_startup_warning(
        f"""
------------------------------------------------------------------------
Error importing dependencies: {e}
{get_missing_requirements_message()}
This error is happening because ComfyUI now uses a local sqlite database.
------------------------------------------------------------------------
""".strip()
    )


def dependencies_available():
    """
    Temporary function to check if the dependencies are available
    """
    return _DB_AVAILABLE


def can_create_session():
    """
    Temporary function to check if the database is available to create a session
    During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
    """
    return dependencies_available() and Session is not None


def get_alembic_config():
    root_path = os.path.join(os.path.dirname(__file__), "../..")
    config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
    scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))

    config = Config(config_path)
    config.set_main_option("script_location", scripts_path)
    config.set_main_option("sqlalchemy.url", args.database_url)

    return config


def get_db_path():
    url = args.database_url
    if url.startswith("sqlite:///"):
        return url.split("///")[1]
    else:
        raise ValueError(f"Unsupported database URL '{url}'.")


def init_db():
    db_url = args.database_url
    logging.debug(f"Database URL: {db_url}")
    db_path = get_db_path()
    db_exists = os.path.exists(db_path)

    config = get_alembic_config()

    # Check if we need to upgrade
    engine = create_engine(db_url)
    conn = engine.connect()

    context = MigrationContext.configure(conn)
    current_rev = context.get_current_revision()

    script = ScriptDirectory.from_config(config)
    target_rev = script.get_current_head()

    if target_rev is None:
        logging.warning("No target revision found.")
    elif current_rev != target_rev:
        # Backup the database pre upgrade
        backup_path = db_path + ".bkp"
        if db_exists:
            shutil.copy(db_path, backup_path)
        else:
            backup_path = None

        try:
            command.upgrade(config, target_rev)
            logging.info(f"Database upgraded from {current_rev} to {target_rev}")
        except Exception as e:
            if backup_path:
                # Restore the database from backup if upgrade fails
                shutil.copy(backup_path, db_path)
                os.remove(backup_path)
            logging.exception("Error upgrading database: ")
            raise e

    global Session
    Session = sessionmaker(bind=engine)


def create_session():
    return Session()
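Taken together, the intended startup flow gates everything on the two availability checks before touching the database; a minimal usage sketch based on the functions above (the session body is illustrative, since this commit defines no models yet):

```python
# Sketch of the expected call order, not part of the commit itself.
from app.database import db

if db.dependencies_available():
    db.init_db()  # applies pending alembic migrations, backing up the sqlite file first

if db.can_create_session():
    session = db.create_session()
    try:
        ...  # query ORM models here once they are defined
    finally:
        session.close()
```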

14  app/database/models.py  (new file)
@@ -0,0 +1,14 @@
from sqlalchemy.orm import declarative_base

Base = declarative_base()


def to_dict(obj):
    fields = obj.__table__.columns.keys()
    return {
        field: (val.to_dict() if hasattr(val, "to_dict") else val)
        for field in fields
        if (val := getattr(obj, field))
    }


# TODO: Define models here
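Note that `to_dict` filters through a walrus clause, so any column whose value is falsy (None, 0, an empty string) is omitted from the result. A hypothetical model, not part of this commit, makes that concrete (assumes it lives in the same module as `Base` and `to_dict`):

```python
# Hypothetical model for illustration only; the commit defines no models yet.
from sqlalchemy import Column, Integer, String

class UserSetting(Base):
    __tablename__ = "user_setting"
    id = Column(Integer, primary_key=True)
    name = Column(String)
    value = Column(String)

s = UserSetting(id=1, name="theme", value="")
print(to_dict(s))  # {'id': 1, 'name': 'theme'} -- the empty 'value' is dropped
```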
@@ -16,26 +16,17 @@ from importlib.metadata import version
 import requests
 from typing_extensions import NotRequired

+from utils.install_util import get_missing_requirements_message, requirements_path
+
 from comfy.cli_args import DEFAULT_VERSION_STRING
 import app.logger

-# The path to the requirements.txt file
-req_path = Path(__file__).parents[1] / "requirements.txt"
-

 def frontend_install_warning_message():
-    """The warning message to display when the frontend version is not up to date."""
-    extra = ""
-    if sys.flags.no_user_site:
-        extra = "-s "
     return f"""
-Please install the updated requirements.txt file by running:
-{sys.executable} {extra}-m pip install -r {req_path}
+{get_missing_requirements_message()}

 This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
-
-If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem
 """.strip()
@@ -48,7 +39,7 @@ def check_frontend_version():
     try:
         frontend_version_str = version("comfyui-frontend-package")
         frontend_version = parse_version(frontend_version_str)
-        with open(req_path, "r", encoding="utf-8") as f:
+        with open(requirements_path, "r", encoding="utf-8") as f:
             required_frontend = parse_version(f.readline().split("=")[-1])
         if frontend_version < required_frontend:
             app.logger.log_startup_warning(
@@ -121,9 +112,22 @@ class FrontEndProvider:
         response.raise_for_status()  # Raises an HTTPError if the response was an error
         return response.json()

+    @cached_property
+    def latest_prerelease(self) -> Release:
+        """Get the latest pre-release version - even if it's older than the latest release"""
+        release = [release for release in self.all_releases if release["prerelease"]]
+
+        if not release:
+            raise ValueError("No pre-releases found")
+
+        # GitHub returns releases in reverse chronological order, so first is latest
+        return release[0]
+
     def get_release(self, version: str) -> Release:
         if version == "latest":
             return self.latest_release
+        elif version == "prerelease":
+            return self.latest_prerelease
         else:
             for release in self.all_releases:
                 if release["tag_name"] in [version, f"v{version}"]:
@@ -230,7 +234,7 @@ comfyui-workflow-templates is not installed.
     Raises:
         argparse.ArgumentTypeError: If the version string is invalid.
     """
-    VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$"
+    VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+[-._a-zA-Z0-9]*|latest|prerelease)$"
     match_result = re.match(VERSION_PATTERN, value)
     if match_result is None:
         raise argparse.ArgumentTypeError(f"Invalid version string: {value}")
@@ -203,6 +203,11 @@ parser.add_argument(
     help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
 )

+database_default_path = os.path.abspath(
+    os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
+)
+parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
+
 if comfy.options.args_parsing:
     args = parser.parse_args()
 else:
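In practice this means ComfyUI can be started with, for example, `--database-url sqlite:///:memory:` to keep the database entirely in memory, while the default stores `comfyui.db` under the `user` directory.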

@@ -37,6 +37,8 @@ class IO(StrEnum):
     CONTROL_NET = "CONTROL_NET"
     VAE = "VAE"
     MODEL = "MODEL"
+    LORA_MODEL = "LORA_MODEL"
+    LOSS_MAP = "LOSS_MAP"
     CLIP_VISION = "CLIP_VISION"
     CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
     STYLE_MODEL = "STYLE_MODEL"
@@ -433,7 +433,8 @@ class ControlLora(ControlNet):
             pass

         for k in self.control_weights:
-            if k not in {"lora_controlnet"}:
+            if (k not in {"lora_controlnet"}):
+                if (k.endswith(".up") or k.endswith(".down") or k.endswith(".weight") or k.endswith(".bias")) and ("__" not in k):
                     comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))

     def copy(self):
@@ -1,4 +1,5 @@
 import math
+from functools import partial

 from scipy import integrate
 import torch
@@ -142,6 +143,33 @@ class BrownianTreeNoiseSampler:
         return self.tree(t0, t1) / (t1 - t0).abs().sqrt()


+def sigma_to_half_log_snr(sigma, model_sampling):
+    """Convert sigma to half-logSNR log(alpha_t / sigma_t)."""
+    if isinstance(model_sampling, comfy.model_sampling.CONST):
+        # log((1 - t) / t) = log((1 - sigma) / sigma)
+        return sigma.logit().neg()
+    return sigma.log().neg()
+
+
+def half_log_snr_to_sigma(half_log_snr, model_sampling):
+    """Convert half-logSNR log(alpha_t / sigma_t) to sigma."""
+    if isinstance(model_sampling, comfy.model_sampling.CONST):
+        # 1 / (1 + exp(half_log_snr))
+        return half_log_snr.neg().sigmoid()
+    return half_log_snr.neg().exp()
+
+
+def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
+    """Adjust the first sigma to avoid invalid logSNR."""
+    if len(sigmas) <= 1:
+        return sigmas
+    if isinstance(model_sampling, comfy.model_sampling.CONST):
+        if sigmas[0] >= 1:
+            sigmas = sigmas.clone()
+            sigmas[0] = model_sampling.percent_to_sigma(percent_offset)
+    return sigmas
+
+
 @torch.no_grad()
 def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
     """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
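For a CONST (flow-style) schedule, where alpha_t = 1 - sigma_t, the two new helpers are exact inverses of each other. A standalone sanity check that restates the relevant branches (independent of any comfy imports, so the names here are re-declarations, not the library's API):

```python
import torch

# lambda = log(alpha / sigma) = log((1 - sigma) / sigma) = -logit(sigma) for CONST models
def to_half_log_snr(sigma):
    return sigma.logit().neg()

# inverse: sigma = 1 / (1 + exp(lambda)) = sigmoid(-lambda)
def to_sigma(half_log_snr):
    return half_log_snr.neg().sigmoid()

sigma = torch.tensor([0.9, 0.5, 0.1])
assert torch.allclose(to_sigma(to_half_log_snr(sigma)), sigma)
```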
@@ -753,6 +781,7 @@ sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
         old_denoised = denoised
     return x


 @torch.no_grad()
 def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
     """DPM-Solver++(2M) SDE."""

@@ -768,9 +797,12 @@ sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])

+    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
+    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
+    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
+
     old_denoised = None
-    h_last = None
-    h = None
+    h, h_last = None, None

     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -781,26 +813,29 @@ sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
             x = denoised
         else:
             # DPM-Solver++(2M) SDE
-            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
-            h = s - t
-            eta_h = eta * h
+            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
+            h = lambda_t - lambda_s
+            h_eta = h * (eta + 1)

-            x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
+            alpha_t = sigmas[i + 1] * lambda_t.exp()
+
+            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised

             if old_denoised is not None:
                 r = h_last / h
                 if solver_type == 'heun':
-                    x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
+                    x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised)
                 elif solver_type == 'midpoint':
-                    x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
+                    x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised)

-            if eta:
-                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
+            if eta > 0 and s_noise > 0:
+                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise

         old_denoised = denoised
         h_last = h
     return x


 @torch.no_grad()
 def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     """DPM-Solver++(3M) SDE."""
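The pattern in these edits is the same throughout: with half-logSNR lambda = log(alpha_t / sigma_t), step size h = lambda_t - lambda_s, and h_eta = h(1 + eta), the deterministic part of each update generalizes from the variance-exploding form to (a restatement of the new code line above, not a quote from the paper):

```latex
x \leftarrow \frac{\sigma_t}{\sigma_s}\, e^{-\eta h}\, x
    + \alpha_t \left(1 - e^{-h(1+\eta)}\right) \hat{x}_\theta,
\qquad \alpha_t = \sigma_t\, e^{\lambda_t}
```

For VE models alpha_t = 1 and lambda = -log(sigma), so this recovers the previous behavior.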
@@ -814,6 +849,10 @@ sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])

+    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
+    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
+    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
+
     denoised_1, denoised_2 = None, None
     h, h_1, h_2 = None, None, None
@@ -825,13 +864,16 @@ sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
             # Denoising step
             x = denoised
         else:
-            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
-            h = s - t
+            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
+            h = lambda_t - lambda_s
             h_eta = h * (eta + 1)

-            x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised
+            alpha_t = sigmas[i + 1] * lambda_t.exp()
+
+            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised

             if h_2 is not None:
+                # DPM-Solver++(3M) SDE
                 r0 = h_1 / h
                 r1 = h_2 / h
                 d1_0 = (denoised - denoised_1) / r0
@@ -840,20 +882,22 @@ sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
                 d2 = (d1_0 - d1_1) / (r0 + r1)
                 phi_2 = h_eta.neg().expm1() / h_eta + 1
                 phi_3 = phi_2 / h_eta - 0.5
-                x = x + phi_2 * d1 - phi_3 * d2
+                x = x + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) * d2
             elif h_1 is not None:
+                # DPM-Solver++(2M) SDE
                 r = h_1 / h
                 d = (denoised - denoised_1) / r
                 phi_2 = h_eta.neg().expm1() / h_eta + 1
-                x = x + phi_2 * d
+                x = x + (alpha_t * phi_2) * d

-            if eta:
+            if eta > 0 and s_noise > 0:
                 x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise

         denoised_1, denoised_2 = denoised, denoised_1
         h_1, h_2 = h, h_1
     return x


 @torch.no_grad()
 def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     if len(sigmas) <= 1:
@@ -863,6 +907,7 @@ sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
     return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)


 @torch.no_grad()
 def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
     if len(sigmas) <= 1:

@@ -872,6 +917,7 @@ sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
     return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)


 @torch.no_grad()
 def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
     if len(sigmas) <= 1:
@@ -1449,12 +1495,12 @@ sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
         old_denoised = denoised
     return x


 @torch.no_grad()
 def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
-    '''
-    SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2
-    Arxiv: https://arxiv.org/abs/2305.14267
-    '''
+    """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
+    arXiv: https://arxiv.org/abs/2305.14267
+    """
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@@ -1462,6 +1508,11 @@ sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non

     inject_noise = eta > 0 and s_noise > 0

+    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
+    sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
+    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
+    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
+
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
         if callback is not None:
@@ -1469,80 +1520,96 @@ sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
         if sigmas[i + 1] == 0:
             x = denoised
         else:
-            t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
-            h = t_next - t
+            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
+            h = lambda_t - lambda_s
             h_eta = h * (eta + 1)
-            s = t + r * h
+            lambda_s_1 = lambda_s + r * h
             fac = 1 / (2 * r)
-            sigma_s = s.neg().exp()
+            sigma_s_1 = sigma_fn(lambda_s_1)
+
+            # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
+            alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
+            alpha_t = sigmas[i + 1] * lambda_t.exp()

             coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
             if inject_noise:
+                # 0 < r < 1
                 noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
-                noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
-                noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1])
+                noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
+                noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])

             # Step 1
-            x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
+            x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
             if inject_noise:
-                x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise
-            denoised_2 = model(x_2, sigma_s * s_in, **extra_args)
+                x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
+            denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)

             # Step 2
             denoised_d = (1 - fac) * denoised + fac * denoised_2
-            x = (coeff_2 + 1) * x - coeff_2 * denoised_d
+            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
             if inject_noise:
                 x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
     return x


 @torch.no_grad()
 def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
-    '''
-    SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3
-    Arxiv: https://arxiv.org/abs/2305.14267
-    '''
+    """SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
+    arXiv: https://arxiv.org/abs/2305.14267
+    """
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])

     inject_noise = eta > 0 and s_noise > 0

+    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
+    sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
+    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
+    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
+
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
         if callback is not None:
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
         if sigmas[i + 1] == 0:
             x = denoised
         else:
-            t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
-            h = t_next - t
+            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
+            h = lambda_t - lambda_s
             h_eta = h * (eta + 1)
-            s_1 = t + r_1 * h
-            s_2 = t + r_2 * h
-            sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp()
+            lambda_s_1 = lambda_s + r_1 * h
+            lambda_s_2 = lambda_s + r_2 * h
+            sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
+
+            # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
+            alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
+            alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
+            alpha_t = sigmas[i + 1] * lambda_t.exp()

             coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
             if inject_noise:
+                # 0 < r_1 < r_2 < 1
                 noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
-                noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt()
-                noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
+                noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
+                noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
                 noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])

             # Step 1
-            x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
+            x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
             if inject_noise:
                 x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
             denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)

             # Step 2
-            x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
+            x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
             if inject_noise:
                 x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
             denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)

             # Step 3
-            x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
+            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
             if inject_noise:
                 x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
     return x
@@ -26,16 +26,6 @@ from torch import nn
 from comfy.ldm.modules.attention import optimized_attention


-def apply_rotary_pos_emb(
-    t: torch.Tensor,
-    freqs: torch.Tensor,
-) -> torch.Tensor:
-    t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
-    t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
-    t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
-    return t_out
-
-
 def get_normalization(name: str, channels: int, weight_args={}, operations=None):
     if name == "I":
         return nn.Identity()
@@ -66,15 +66,16 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
         h_extrapolation_ratio: float = 1.0,
         w_extrapolation_ratio: float = 1.0,
         t_extrapolation_ratio: float = 1.0,
+        enable_fps_modulation: bool = True,
         device=None,
         **kwargs,  # used for compatibility with other positional embeddings; unused in this class
     ):
         del kwargs
         super().__init__()
-        self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device))
         self.base_fps = base_fps
         self.max_h = len_h
         self.max_w = len_w
+        self.enable_fps_modulation = enable_fps_modulation

         dim = head_dim
         dim_h = dim // 6 * 2
@@ -132,21 +133,19 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
         temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))

         B, T, H, W, _ = B_T_H_W_C
+        seq = torch.arange(max(H, W, T), dtype=torch.float, device=device)
         uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
         assert (
             uniform_fps or B == 1 or T == 1
         ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
-        assert (
-            H <= self.max_h and W <= self.max_w
-        ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
-        half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs)
-        half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs)
+        half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs)
+        half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs)

         # apply sequence scaling in temporal dimension
-        if fps is None:  # image case
-            half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs)
+        if fps is None or self.enable_fps_modulation is False:  # image case
+            half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs)
         else:
-            half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
+            half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)

         half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
         half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)

864  comfy/ldm/cosmos/predict2.py  (new file)
@@ -0,0 +1,864 @@
|
|||||||
|
# original code from: https://github.com/nvidia-cosmos/cosmos-predict2
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from einops import rearrange
|
||||||
|
from einops.layers.torch import Rearrange
|
||||||
|
import logging
|
||||||
|
from typing import Callable, Optional, Tuple
|
||||||
|
import math
|
||||||
|
|
||||||
|
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb(
|
||||||
|
t: torch.Tensor,
|
||||||
|
freqs: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
|
||||||
|
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
|
||||||
|
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
|
||||||
|
return t_out
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------- Feed Forward Network -----------------------
|
||||||
|
class GPT2FeedForward(nn.Module):
|
||||||
|
def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.activation = nn.GELU()
|
||||||
|
self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype)
|
||||||
|
self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self._layer_id = None
|
||||||
|
self._dim = d_model
|
||||||
|
self._hidden_dim = d_ff
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = self.layer1(x)
|
||||||
|
|
||||||
|
x = self.activation(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
    """Computes multi-head attention using PyTorch's native implementation.

    This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
    It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product
    attention, and rearranges the output back to the original format.

    The input tensor names use the following dimension conventions:

    - B: batch size
    - S: sequence length
    - H: number of attention heads
    - D: head dimension

    Args:
        q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim)
        k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim)
        v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim)

    Returns:
        Attention output tensor with shape (batch, seq_len, n_heads * head_dim)
    """
    in_q_shape = q_B_S_H_D.shape
    in_k_shape = k_B_S_H_D.shape
    q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
    k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
    v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
    return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)

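Editor's note: a shape-level sketch of how torch_attention_op is meant to be called (toy sizes; assumes a working ComfyUI environment so that optimized_attention resolves to a real backend):

q = torch.randn(2, 1024, 8, 64)    # (B, S, H, D)
k = torch.randn(2, 1024, 8, 64)
v = torch.randn(2, 1024, 8, 64)
out = torch_attention_op(q, k, v)  # heads folded back in: (B, S, H * D) == (2, 1024, 512)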
class Attention(nn.Module):
    """
    A flexible attention module supporting both self-attention and cross-attention mechanisms.

    This module implements a multi-head attention layer that can operate in either self-attention
    or cross-attention mode. The mode is determined by whether a context dimension is provided.
    The implementation uses scaled dot-product attention and supports optional bias terms and
    dropout regularization.

    Args:
        query_dim (int): The dimensionality of the query vectors.
        context_dim (int, optional): The dimensionality of the context (key/value) vectors.
            If None, the module operates in self-attention mode using query_dim. Default: None
        n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8
        head_dim (int, optional): The dimension of each attention head. Default: 64
        dropout (float, optional): Dropout probability applied to the output. Default: 0.0

    Examples:
        >>> # Self-attention with 512 dimensions and 8 heads
        >>> self_attn = Attention(query_dim=512)
        >>> x = torch.randn(32, 16, 512)  # (batch_size, seq_len, dim)
        >>> out = self_attn(x)  # (32, 16, 512)

        >>> # Cross-attention
        >>> cross_attn = Attention(query_dim=512, context_dim=256)
        >>> query = torch.randn(32, 16, 512)
        >>> context = torch.randn(32, 8, 256)
        >>> out = cross_attn(query, context)  # (32, 16, 512)
    """

    def __init__(
        self,
        query_dim: int,
        context_dim: Optional[int] = None,
        n_heads: int = 8,
        head_dim: int = 64,
        dropout: float = 0.0,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        super().__init__()
        logging.debug(
            f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
            f"{n_heads} heads with a dimension of {head_dim}."
        )
        self.is_selfattn = context_dim is None  # self attention

        context_dim = query_dim if context_dim is None else context_dim
        inner_dim = head_dim * n_heads

        self.n_heads = n_heads
        self.head_dim = head_dim
        self.query_dim = query_dim
        self.context_dim = context_dim

        self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
        self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)

        self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
        self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)

        self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
        self.v_norm = nn.Identity()

        self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
        self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()

        self.attn_op = torch_attention_op

        self._query_dim = query_dim
        self._context_dim = context_dim
        self._inner_dim = inner_dim

    def compute_qkv(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        rope_emb: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        q = self.q_proj(x)
        context = x if context is None else context
        k = self.k_proj(context)
        v = self.v_proj(context)
        q, k, v = map(
            lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim),
            (q, k, v),
        )

        def apply_norm_and_rotary_pos_emb(
            q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor]
        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            q = self.q_norm(q)
            k = self.k_norm(k)
            v = self.v_norm(v)
            if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
                q = apply_rotary_pos_emb(q, rope_emb)
                k = apply_rotary_pos_emb(k, rope_emb)
            return q, k, v

        q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)

        return q, k, v

    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        result = self.attn_op(q, k, v)  # [B, S, H, D]
        return self.output_dropout(self.output_proj(result))

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        rope_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x (Tensor): The query tensor of shape [B, Mq, K]
            context (Optional[Tensor]): The key tensor of shape [B, Mk, K], or None to use x as context (self-attention)
        """
        q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
        return self.compute_attention(q, k, v)

class Timesteps(nn.Module):
    def __init__(self, num_channels: int):
        super().__init__()
        self.num_channels = num_channels

    def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor:
        assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}"
        timesteps = timesteps_B_T.flatten().float()
        half_dim = self.num_channels // 2
        exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
        exponent = exponent / (half_dim - 0.0)

        emb = torch.exp(exponent)
        emb = timesteps[:, None].float() * emb[None, :]

        sin_emb = torch.sin(emb)
        cos_emb = torch.cos(emb)
        emb = torch.cat([cos_emb, sin_emb], dim=-1)

        return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1])

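Editor's note: Timesteps is the standard sinusoidal embedding, with per-channel frequencies decaying geometrically from 1 down to roughly 1/10000 across half_dim channels. A minimal standalone check of the frequency table (toy size):

import math
import torch

half_dim = 4
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32) / half_dim
freqs = torch.exp(exponent)  # tensor([1.0000, 0.1000, 0.0100, 0.0010])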
class TimestepEmbedding(nn.Module):
    def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):
        super().__init__()
        logging.debug(
            f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
        )
        self.in_dim = in_features
        self.out_dim = out_features
        self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype)
        self.activation = nn.SiLU()
        self.use_adaln_lora = use_adaln_lora
        if use_adaln_lora:
            self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype)
        else:
            self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype)

    def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        emb = self.linear_1(sample)
        emb = self.activation(emb)
        emb = self.linear_2(emb)

        if self.use_adaln_lora:
            adaln_lora_B_T_3D = emb
            emb_B_T_D = sample
        else:
            adaln_lora_B_T_3D = None
            emb_B_T_D = emb

        return emb_B_T_D, adaln_lora_B_T_3D

class PatchEmbed(nn.Module):
    """
    PatchEmbed is a module for embedding patches from an input tensor by rearranging each spatio-temporal
    patch into a vector and applying a linear projection. This module can process inputs with temporal
    (video) and spatial (image) dimensions, making it suitable for video and image processing tasks.
    It supports dividing the input into patches and embedding each patch into a vector of size `out_channels`.

    Parameters:
    - spatial_patch_size (int): The size of each spatial patch.
    - temporal_patch_size (int): The size of each temporal patch.
    - in_channels (int): Number of input channels. Default: 3.
    - out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
    """

    def __init__(
        self,
        spatial_patch_size: int,
        temporal_patch_size: int,
        in_channels: int = 3,
        out_channels: int = 768,
        device=None, dtype=None, operations=None
    ):
        super().__init__()
        self.spatial_patch_size = spatial_patch_size
        self.temporal_patch_size = temporal_patch_size

        self.proj = nn.Sequential(
            Rearrange(
                "b c (t r) (h m) (w n) -> b t h w (c r m n)",
                r=temporal_patch_size,
                m=spatial_patch_size,
                n=spatial_patch_size,
            ),
            operations.Linear(
                in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype
            ),
        )
        self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the PatchEmbed module.

        Parameters:
        - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
            B is the batch size,
            C is the number of channels,
            T is the temporal dimension,
            H is the height, and
            W is the width of the input.

        Returns:
        - torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
        """
        assert x.dim() == 5
        _, _, T, H, W = x.shape
        assert (
            H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
        ), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}"
        assert T % self.temporal_patch_size == 0
        x = self.proj(x)
        return x

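Editor's note: a quick shape check as a sketch (toy configuration; plain torch.nn stands in for comfy's operations wrapper here, which is an assumption of this example since nn.Linear also accepts device/dtype kwargs):

import torch
import torch.nn as nn

embed = PatchEmbed(spatial_patch_size=2, temporal_patch_size=1,
                   in_channels=16, out_channels=32,
                   operations=nn)
x = torch.randn(1, 16, 8, 64, 64)  # (B, C, T, H, W)
# each 1x2x2 patch flattens to 16*1*2*2 = 64 values, projected to 32
print(embed(x).shape)              # torch.Size([1, 8, 32, 32, 32])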
class FinalLayer(nn.Module):
    """
    The final layer of video DiT.
    """

    def __init__(
        self,
        hidden_size: int,
        spatial_patch_size: int,
        temporal_patch_size: int,
        out_channels: int,
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        device=None, dtype=None, operations=None
    ):
        super().__init__()
        self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = operations.Linear(
            hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
        )
        self.hidden_size = hidden_size
        self.n_adaln_chunks = 2
        self.use_adaln_lora = use_adaln_lora
        self.adaln_lora_dim = adaln_lora_dim
        if use_adaln_lora:
            self.adaln_modulation = nn.Sequential(
                nn.SiLU(),
                operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype),
                operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype),
            )
        else:
            self.adaln_modulation = nn.Sequential(
                nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype)
            )

    def forward(
        self,
        x_B_T_H_W_D: torch.Tensor,
        emb_B_T_D: torch.Tensor,
        adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
    ):
        if self.use_adaln_lora:
            assert adaln_lora_B_T_3D is not None
            shift_B_T_D, scale_B_T_D = (
                self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
            ).chunk(2, dim=-1)
        else:
            shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)

        shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d"), rearrange(
            scale_B_T_D, "b t d -> b t 1 1 d"
        )

        def _fn(
            _x_B_T_H_W_D: torch.Tensor,
            _norm_layer: nn.Module,
            _scale_B_T_1_1_D: torch.Tensor,
            _shift_B_T_1_1_D: torch.Tensor,
        ) -> torch.Tensor:
            return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D

        x_B_T_H_W_D = _fn(x_B_T_H_W_D, self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D)
        x_B_T_H_W_O = self.linear(x_B_T_H_W_D)
        return x_B_T_H_W_O

class Block(nn.Module):
    """
    A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation.
    Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation.

    Parameters:
        x_dim (int): Dimension of input features
        context_dim (int): Dimension of context features for cross-attention
        num_heads (int): Number of attention heads
        mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0
        use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False
        adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256

    The block applies the following sequence:
    1. Self-attention with AdaLN modulation
    2. Cross-attention with AdaLN modulation
    3. MLP with AdaLN modulation

    Each component uses skip connections and layer normalization.
    """

    def __init__(
        self,
        x_dim: int,
        context_dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.x_dim = x_dim
        self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations)

        self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.cross_attn = Attention(
            x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations
        )

        self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations)

        self.use_adaln_lora = use_adaln_lora
        if self.use_adaln_lora:
            self.adaln_modulation_self_attn = nn.Sequential(
                nn.SiLU(),
                operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
                operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
            )
            self.adaln_modulation_cross_attn = nn.Sequential(
                nn.SiLU(),
                operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
                operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
            )
            self.adaln_modulation_mlp = nn.Sequential(
                nn.SiLU(),
                operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
                operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
            )
        else:
            self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
            self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
            self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))

    def forward(
        self,
        x_B_T_H_W_D: torch.Tensor,
        emb_B_T_D: torch.Tensor,
        crossattn_emb: torch.Tensor,
        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
        adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
        extra_per_block_pos_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if extra_per_block_pos_emb is not None:
            x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb

        if self.use_adaln_lora:
            shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
                self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
            ).chunk(3, dim=-1)
            shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
                self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
            ).chunk(3, dim=-1)
            shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (
                self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D
            ).chunk(3, dim=-1)
        else:
            shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
                emb_B_T_D
            ).chunk(3, dim=-1)
            shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
                emb_B_T_D
            ).chunk(3, dim=-1)
            shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)

        # Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting
        shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d")
        scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d")
        gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d")

        shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d")
        scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d")
        gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d")

        shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d")
        scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d")
        gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d")

        B, T, H, W, D = x_B_T_H_W_D.shape

        def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D):
            return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D

        normalized_x_B_T_H_W_D = _fn(
            x_B_T_H_W_D,
            self.layer_norm_self_attn,
            scale_self_attn_B_T_1_1_D,
            shift_self_attn_B_T_1_1_D,
        )
        result_B_T_H_W_D = rearrange(
            self.self_attn(
                rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
                None,
                rope_emb=rope_emb_L_1_1_D,
            ),
            "b (t h w) d -> b t h w d",
            t=T,
            h=H,
            w=W,
        )
        x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D

        def _x_fn(
            _x_B_T_H_W_D: torch.Tensor,
            layer_norm_cross_attn: Callable,
            _scale_cross_attn_B_T_1_1_D: torch.Tensor,
            _shift_cross_attn_B_T_1_1_D: torch.Tensor,
        ) -> torch.Tensor:
            _normalized_x_B_T_H_W_D = _fn(
                _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
            )
            _result_B_T_H_W_D = rearrange(
                self.cross_attn(
                    rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
                    crossattn_emb,
                    rope_emb=rope_emb_L_1_1_D,
                ),
                "b (t h w) d -> b t h w d",
                t=T,
                h=H,
                w=W,
            )
            return _result_B_T_H_W_D

        result_B_T_H_W_D = _x_fn(
            x_B_T_H_W_D,
            self.layer_norm_cross_attn,
            scale_cross_attn_B_T_1_1_D,
            shift_cross_attn_B_T_1_1_D,
        )
        x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D

        normalized_x_B_T_H_W_D = _fn(
            x_B_T_H_W_D,
            self.layer_norm_mlp,
            scale_mlp_B_T_1_1_D,
            shift_mlp_B_T_1_1_D,
        )
        result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
        x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
        return x_B_T_H_W_D

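Editor's note: each sub-layer above follows the adaLN pattern x + gate * f(norm(x) * (1 + scale) + shift), with shift/scale/gate predicted from the timestep embedding. A minimal numeric sketch of the modulation itself (toy tensors, no learned weights):

import torch

x = torch.randn(1, 2, 4, 4, 8)           # (B, T, H, W, D)
shift = torch.zeros(1, 2, 1, 1, 8)
scale = torch.full((1, 2, 1, 1, 8), 0.5)
norm = torch.nn.LayerNorm(8, elementwise_affine=False)
modulated = norm(x) * (1 + scale) + shift  # same _fn as in Block/FinalLayer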
class MiniTrainDIT(nn.Module):
    """
    A clean implementation of DiT that can load and reproduce the training results of the original DiT model (Cosmos 1).
    A general implementation of an adaln-modulated ViT-like (DiT) transformer for video processing.

    Args:
        max_img_h (int): Maximum height of the input images.
        max_img_w (int): Maximum width of the input images.
        max_frames (int): Maximum number of frames in the video sequence.
        in_channels (int): Number of input channels (e.g., RGB channels for color images).
        out_channels (int): Number of output channels.
        patch_spatial (tuple): Spatial resolution of patches for input processing.
        patch_temporal (int): Temporal resolution of patches for input processing.
        concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
        model_channels (int): Base number of channels used throughout the model.
        num_blocks (int): Number of transformer blocks.
        num_heads (int): Number of heads in the multi-head attention layers.
        mlp_ratio (float): Expansion ratio for MLP blocks.
        crossattn_emb_channels (int): Number of embedding channels for cross-attention.
        pos_emb_cls (str): Type of positional embeddings.
        pos_emb_learnable (bool): Whether positional embeddings are learnable.
        pos_emb_interpolation (str): Method for interpolating positional embeddings.
        min_fps (int): Minimum frames per second.
        max_fps (int): Maximum frames per second.
        use_adaln_lora (bool): Whether to use AdaLN-LoRA.
        adaln_lora_dim (int): Dimension for AdaLN-LoRA.
        rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
        rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
        rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
        extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
        extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
        extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
        extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
    """

    def __init__(
        self,
        max_img_h: int,
        max_img_w: int,
        max_frames: int,
        in_channels: int,
        out_channels: int,
        patch_spatial: int,  # tuple,
        patch_temporal: int,
        concat_padding_mask: bool = True,
        # attention settings
        model_channels: int = 768,
        num_blocks: int = 10,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        # cross attention settings
        crossattn_emb_channels: int = 1024,
        # positional embedding settings
        pos_emb_cls: str = "sincos",
        pos_emb_learnable: bool = False,
        pos_emb_interpolation: str = "crop",
        min_fps: int = 1,
        max_fps: int = 30,
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        rope_h_extrapolation_ratio: float = 1.0,
        rope_w_extrapolation_ratio: float = 1.0,
        rope_t_extrapolation_ratio: float = 1.0,
        extra_per_block_abs_pos_emb: bool = False,
        extra_h_extrapolation_ratio: float = 1.0,
        extra_w_extrapolation_ratio: float = 1.0,
        extra_t_extrapolation_ratio: float = 1.0,
        rope_enable_fps_modulation: bool = True,
        image_model=None,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        super().__init__()
        self.dtype = dtype
        self.max_img_h = max_img_h
        self.max_img_w = max_img_w
        self.max_frames = max_frames
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_spatial = patch_spatial
        self.patch_temporal = patch_temporal
        self.num_heads = num_heads
        self.num_blocks = num_blocks
        self.model_channels = model_channels
        self.concat_padding_mask = concat_padding_mask
        # positional embedding settings
        self.pos_emb_cls = pos_emb_cls
        self.pos_emb_learnable = pos_emb_learnable
        self.pos_emb_interpolation = pos_emb_interpolation
        self.min_fps = min_fps
        self.max_fps = max_fps
        self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
        self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
        self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
        self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
        self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
        self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
        self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
        self.rope_enable_fps_modulation = rope_enable_fps_modulation

        self.build_pos_embed(device=device, dtype=dtype)
        self.use_adaln_lora = use_adaln_lora
        self.adaln_lora_dim = adaln_lora_dim
        self.t_embedder = nn.Sequential(
            Timesteps(model_channels),
            TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,),
        )

        in_channels = in_channels + 1 if concat_padding_mask else in_channels
        self.x_embedder = PatchEmbed(
            spatial_patch_size=patch_spatial,
            temporal_patch_size=patch_temporal,
            in_channels=in_channels,
            out_channels=model_channels,
            device=device, dtype=dtype, operations=operations,
        )

        self.blocks = nn.ModuleList(
            [
                Block(
                    x_dim=model_channels,
                    context_dim=crossattn_emb_channels,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    use_adaln_lora=use_adaln_lora,
                    adaln_lora_dim=adaln_lora_dim,
                    device=device, dtype=dtype, operations=operations,
                )
                for _ in range(num_blocks)
            ]
        )

        self.final_layer = FinalLayer(
            hidden_size=self.model_channels,
            spatial_patch_size=self.patch_spatial,
            temporal_patch_size=self.patch_temporal,
            out_channels=self.out_channels,
            use_adaln_lora=self.use_adaln_lora,
            adaln_lora_dim=self.adaln_lora_dim,
            device=device, dtype=dtype, operations=operations,
        )

        self.t_embedding_norm = operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype)

    def build_pos_embed(self, device=None, dtype=None) -> None:
        if self.pos_emb_cls == "rope3d":
            cls_type = VideoRopePosition3DEmb
        else:
            raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")

        logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
        kwargs = dict(
            model_channels=self.model_channels,
            len_h=self.max_img_h // self.patch_spatial,
            len_w=self.max_img_w // self.patch_spatial,
            len_t=self.max_frames // self.patch_temporal,
            max_fps=self.max_fps,
            min_fps=self.min_fps,
            is_learnable=self.pos_emb_learnable,
            interpolation=self.pos_emb_interpolation,
            head_dim=self.model_channels // self.num_heads,
            h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
            w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
            t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
            enable_fps_modulation=self.rope_enable_fps_modulation,
            device=device,
        )
        self.pos_embedder = cls_type(
            **kwargs,  # type: ignore
        )

        if self.extra_per_block_abs_pos_emb:
            kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
            kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
            kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
            kwargs["device"] = device
            kwargs["dtype"] = dtype
            self.extra_pos_embedder = LearnablePosEmbAxis(
                **kwargs,  # type: ignore
            )

    def prepare_embedded_sequence(
        self,
        x_B_C_T_H_W: torch.Tensor,
        fps: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.

        Args:
            x_B_C_T_H_W (torch.Tensor): video
            fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
                If None, a default value (`self.base_fps`) will be used.
            padding_mask (Optional[torch.Tensor]): currently unused

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]:
                - A tensor of shape (B, T, H, W, D) with the embedded sequence.
                - An optional positional embedding tensor, returned only if the positional embedding class
                  (`self.pos_emb_cls`) includes 'rope'. Otherwise, None.

        Notes:
            - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
            - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
            - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
              the `self.pos_embedder` with the shape [T, H, W].
            - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
              `self.pos_embedder` with the fps tensor.
            - Otherwise, the positional embeddings are generated without considering fps.
        """
        if self.concat_padding_mask:
            if padding_mask is None:
                padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
            else:
                padding_mask = transforms.functional.resize(
                    padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
                )
            x_B_C_T_H_W = torch.cat(
                [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
            )
        x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)

        if self.extra_per_block_abs_pos_emb:
            extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
        else:
            extra_pos_emb = None

        if "rope" in self.pos_emb_cls.lower():
            return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
        x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device)  # [B, T, H, W, D]

        return x_B_T_H_W_D, None, extra_pos_emb

    def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor:
        x_B_C_Tt_Hp_Wp = rearrange(
            x_B_T_H_W_M,
            "B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
            p1=self.patch_spatial,
            p2=self.patch_spatial,
            t=self.patch_temporal,
        )
        return x_B_C_Tt_Hp_Wp

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        context: torch.Tensor,
        fps: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Args:
            x: (B, C, T, H, W) tensor of spatial-temp inputs
            timesteps: (B, ) tensor of timesteps
            context: (B, N, D) tensor of cross-attention embeddings
        """
        x_B_C_T_H_W = x
        timesteps_B_T = timesteps
        crossattn_emb = context
        x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
            x_B_C_T_H_W,
            fps=fps,
            padding_mask=padding_mask,
        )

        if timesteps_B_T.ndim == 1:
            timesteps_B_T = timesteps_B_T.unsqueeze(1)
        t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype))
        t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D)

        # for logging purposes
        affline_scale_log_info = {}
        affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach()
        self.affline_scale_log_info = affline_scale_log_info
        self.affline_emb = t_embedding_B_T_D
        self.crossattn_emb = crossattn_emb

        if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
            assert (
                x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
            ), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}"

        block_kwargs = {
            "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
            "adaln_lora_B_T_3D": adaln_lora_B_T_3D,
            "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
        }
        for block in self.blocks:
            x_B_T_H_W_D = block(
                x_B_T_H_W_D,
                t_embedding_B_T_D,
                crossattn_emb,
                **block_kwargs,
            )

        x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
        x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
        return x_B_C_Tt_Hp_Wp
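Editor's note: a hedged end-to-end sketch of driving MiniTrainDIT directly with a tiny toy config. It assumes comfy.ops.disable_weight_init provides the Linear/LayerNorm/RMSNorm wrappers the constructor expects; in ComfyUI the model loader normally wires this up, so treat this as illustration, not the supported entry point:

import torch
import comfy.ops
import comfy.ldm.cosmos.predict2

model = comfy.ldm.cosmos.predict2.MiniTrainDIT(
    max_img_h=240, max_img_w=240, max_frames=128,
    in_channels=16, out_channels=16,
    patch_spatial=2, patch_temporal=1,
    model_channels=256, num_blocks=2, num_heads=4,
    crossattn_emb_channels=1024,
    pos_emb_cls="rope3d", use_adaln_lora=True,
    operations=comfy.ops.disable_weight_init,
)
x = torch.randn(1, 16, 1, 64, 64)   # (B, C, T, H, W) latent
t = torch.rand(1)                    # one timestep per batch item
ctx = torch.randn(1, 77, 1024)       # cross-attention embeddings
out = model(x, t, ctx)               # expected shape: (1, 16, 1, 64, 64)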
@@ -121,6 +121,9 @@ class ControlNetFlux(Flux):
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")

+        if y is None:
+            y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
+
         # running on sequences img
         img = self.img_in(img)

@@ -174,7 +177,7 @@ class ControlNetFlux(Flux):
         out["output"] = out_output[:self.main_model_single]
         return out

-    def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
+    def forward(self, x, timesteps, context, y=None, guidance=None, hint=None, **kwargs):
         patch_size = 2
         if self.latent_input:
             hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
@@ -101,6 +101,10 @@ class Flux(nn.Module):
         transformer_options={},
         attn_mask: Tensor = None,
     ) -> Tensor:
+
+        if y is None:
+            y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
+
         patches_replace = transformer_options.get("patches_replace", {})
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")

@@ -155,6 +159,9 @@ class Flux(nn.Module):
             if add is not None:
                 img += add

+        if img.dtype == torch.float16:
+            img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
+
         img = torch.cat((txt, img), 1)

         for i, block in enumerate(self.single_blocks):

@@ -188,7 +195,7 @@ class Flux(nn.Module):
         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
         return img

-    def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, y=None, guidance=None, control=None, transformer_options={}, **kwargs):
         bs, c, h, w = x.shape
         patch_size = self.patch_size
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
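Editor's note: the new fp16 guard clamps to the float16 finite range (max 65504) so an overflowed or NaN intermediate does not poison the concatenation that follows. A standalone illustration:

import torch

x = torch.tensor([1.0, float("inf"), float("nan")], dtype=torch.float16)
print(torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504))
# tensor([1.0000e+00, 6.5504e+04, 0.0000e+00], dtype=torch.float16)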
@@ -261,8 +261,8 @@ class CrossAttention(nn.Module):
         self.heads = heads
         self.dim_head = dim_head

-        self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
-        self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
+        self.q_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
+        self.k_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)

         self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
         self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
@@ -753,7 +753,7 @@ class BasicTransformerBlock(nn.Module):
             for p in patch:
                 n = p(n, extra_options)

-        x += n
+        x = n + x
         if "middle_patch" in transformer_patches:
             patch = transformer_patches["middle_patch"]
             for p in patch:

@@ -793,12 +793,12 @@ class BasicTransformerBlock(nn.Module):
             for p in patch:
                 n = p(n, extra_options)

-        x += n
+        x = n + x
         if self.is_res:
             x_skip = x
         x = self.ff(self.norm3(x))
         if self.is_res:
-            x += x_skip
+            x = x_skip + x

         return x
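Editor's note: swapping `x += n` for `x = n + x` trades an in-place add for a fresh tensor. One practical difference, and a plausible motivation for the change, is that in-place addition is pinned to x's dtype and storage, while out-of-place addition follows normal type promotion:

import torch

x = torch.ones(2, dtype=torch.float16)
n = torch.ones(2, dtype=torch.float32)
y = n + x        # out-of-place: promotes to float32
# x += n         # would raise: result type Float can't be cast to Half in-place
print(y.dtype)   # torch.float32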
@@ -34,6 +34,7 @@ import comfy.ldm.flux.model
 import comfy.ldm.lightricks.model
 import comfy.ldm.hunyuan_video.model
 import comfy.ldm.cosmos.model
+import comfy.ldm.cosmos.predict2
 import comfy.ldm.lumina.model
 import comfy.ldm.wan.model
 import comfy.ldm.hunyuan3d.model

@@ -48,6 +49,7 @@ import comfy.ops
 from enum import Enum
 from . import utils
 import comfy.latent_formats
+import comfy.model_sampling
 import math
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
@@ -63,38 +65,39 @@ class ModelType(Enum):
     V_PREDICTION_CONTINUOUS = 7
     FLUX = 8
     IMG_TO_IMG = 9
+    FLOW_COSMOS = 10

-from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV


 def model_sampling(model_config, model_type):
-    s = ModelSamplingDiscrete
+    s = comfy.model_sampling.ModelSamplingDiscrete

     if model_type == ModelType.EPS:
-        c = EPS
+        c = comfy.model_sampling.EPS
     elif model_type == ModelType.V_PREDICTION:
-        c = V_PREDICTION
+        c = comfy.model_sampling.V_PREDICTION
     elif model_type == ModelType.V_PREDICTION_EDM:
-        c = V_PREDICTION
-        s = ModelSamplingContinuousEDM
+        c = comfy.model_sampling.V_PREDICTION
+        s = comfy.model_sampling.ModelSamplingContinuousEDM
     elif model_type == ModelType.FLOW:
         c = comfy.model_sampling.CONST
         s = comfy.model_sampling.ModelSamplingDiscreteFlow
     elif model_type == ModelType.STABLE_CASCADE:
-        c = EPS
-        s = StableCascadeSampling
+        c = comfy.model_sampling.EPS
+        s = comfy.model_sampling.StableCascadeSampling
     elif model_type == ModelType.EDM:
-        c = EDM
-        s = ModelSamplingContinuousEDM
+        c = comfy.model_sampling.EDM
+        s = comfy.model_sampling.ModelSamplingContinuousEDM
     elif model_type == ModelType.V_PREDICTION_CONTINUOUS:
-        c = V_PREDICTION
-        s = ModelSamplingContinuousV
+        c = comfy.model_sampling.V_PREDICTION
+        s = comfy.model_sampling.ModelSamplingContinuousV
     elif model_type == ModelType.FLUX:
         c = comfy.model_sampling.CONST
         s = comfy.model_sampling.ModelSamplingFlux
     elif model_type == ModelType.IMG_TO_IMG:
         c = comfy.model_sampling.IMG_TO_IMG
+    elif model_type == ModelType.FLOW_COSMOS:
+        c = comfy.model_sampling.COSMOS_RFLOW
+        s = comfy.model_sampling.ModelSamplingCosmosRFlow

     class ModelSampling(s, c):
         pass
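Editor's note: `class ModelSampling(s, c)` builds the sampling object from two mixins chosen at runtime: `s` supplies the sigma schedule, `c` the prediction-type math. A stripped-down sketch of the pattern (hypothetical toy classes, not the real ComfyUI ones):

class Schedule:
    def sigma(self, t):
        return t

class Prediction:
    def calculate_denoised(self, sigma, model_output, model_input):
        return model_input - sigma * model_output

def make_sampling(schedule_cls, prediction_cls):
    class ModelSampling(schedule_cls, prediction_cls):
        pass
    return ModelSampling()

ms = make_sampling(Schedule, Prediction)  # one object combining both behaviors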
@@ -102,6 +105,13 @@ def model_sampling(model_config, model_type):
     return ModelSampling(model_config)


+def convert_tensor(extra, dtype):
+    if hasattr(extra, "dtype"):
+        if extra.dtype != torch.int and extra.dtype != torch.long:
+            extra = extra.to(dtype)
+    return extra
+
+
 class BaseModel(torch.nn.Module):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
         super().__init__()
@@ -165,13 +175,13 @@ class BaseModel(torch.nn.Module):
         extra_conds = {}
         for o in kwargs:
             extra = kwargs[o]

             if hasattr(extra, "dtype"):
-                if extra.dtype != torch.int and extra.dtype != torch.long:
-                    extra = extra.to(dtype)
-            if isinstance(extra, list):
+                extra = convert_tensor(extra, dtype)
+            elif isinstance(extra, list):
                 ex = []
                 for ext in extra:
-                    ex.append(ext.to(dtype))
+                    ex.append(convert_tensor(ext, dtype))
                 extra = ex
             extra_conds[o] = extra
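Editor's note: the refactor routes every tensor through convert_tensor, which casts floating-point conditioning to the model dtype while leaving integer tensors (e.g. token ids) and non-tensors untouched. A quick check as a sketch:

import torch

print(convert_tensor(torch.randn(2), torch.float16).dtype)        # torch.float16
print(convert_tensor(torch.tensor([1, 2]), torch.float16).dtype)  # torch.int64, left alone
print(convert_tensor("not a tensor", torch.float16))              # passed through unchanged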
@@ -991,6 +1001,43 @@ class CosmosVideo(BaseModel):
         latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
         return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5)

+class CosmosPredict2(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW_COSMOS, image_to_video=False, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.predict2.MiniTrainDIT)
+        self.image_to_video = image_to_video
+        if self.image_to_video:
+            self.concat_keys = ("mask_inverted",)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+        denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+        if denoise_mask is not None:
+            out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
+
+        out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None))
+        return out
+
+    def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
+        if denoise_mask is None:
+            return timestep
+        condition_video_mask_B_1_T_1_1 = denoise_mask.mean(dim=[1, 3, 4], keepdim=True)
+        c_noise_B_1_T_1_1 = 0.0 * (1.0 - condition_video_mask_B_1_T_1_1) + timestep.reshape(timestep.shape[0], 1, 1, 1, 1) * condition_video_mask_B_1_T_1_1
+        out = c_noise_B_1_T_1_1.squeeze(dim=[1, 3, 4])
+        return out
+
+    def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
+        sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1))
+        sigma_noise_augmentation = 0 #TODO
+        if sigma_noise_augmentation != 0:
+            latent_image = latent_image + noise
+        latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
+        sigma = (sigma / (sigma + 1))
+        return latent_image / (1.0 - sigma)
+
 class Lumina2(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
@@ -407,6 +407,58 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["text_emb_dim"] = 2048
         return dit_config

+    if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2
+        dit_config = {}
+        dit_config["image_model"] = "cosmos_predict2"
+        dit_config["max_img_h"] = 240
+        dit_config["max_img_w"] = 240
+        dit_config["max_frames"] = 128
+        concat_padding_mask = True
+        dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
+        dit_config["out_channels"] = 16
+        dit_config["patch_spatial"] = 2
+        dit_config["patch_temporal"] = 1
+        dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
+        dit_config["concat_padding_mask"] = concat_padding_mask
+        dit_config["crossattn_emb_channels"] = 1024
+        dit_config["pos_emb_cls"] = "rope3d"
+        dit_config["pos_emb_learnable"] = True
+        dit_config["pos_emb_interpolation"] = "crop"
+        dit_config["min_fps"] = 1
+        dit_config["max_fps"] = 30
+
+        dit_config["use_adaln_lora"] = True
+        dit_config["adaln_lora_dim"] = 256
+        if dit_config["model_channels"] == 2048:
+            dit_config["num_blocks"] = 28
+            dit_config["num_heads"] = 16
+        elif dit_config["model_channels"] == 5120:
+            dit_config["num_blocks"] = 36
+            dit_config["num_heads"] = 40
+
+        if dit_config["in_channels"] == 16:
+            dit_config["extra_per_block_abs_pos_emb"] = False
+            dit_config["rope_h_extrapolation_ratio"] = 4.0
+            dit_config["rope_w_extrapolation_ratio"] = 4.0
+            dit_config["rope_t_extrapolation_ratio"] = 1.0
+        elif dit_config["in_channels"] == 17: # img to video
+            if dit_config["model_channels"] == 2048:
+                dit_config["extra_per_block_abs_pos_emb"] = False
+                dit_config["rope_h_extrapolation_ratio"] = 3.0
+                dit_config["rope_w_extrapolation_ratio"] = 3.0
+                dit_config["rope_t_extrapolation_ratio"] = 1.0
+            elif dit_config["model_channels"] == 5120:
+                dit_config["rope_h_extrapolation_ratio"] = 2.0
+                dit_config["rope_w_extrapolation_ratio"] = 2.0
+                dit_config["rope_t_extrapolation_ratio"] = 0.8333333333333334
+
+        dit_config["extra_h_extrapolation_ratio"] = 1.0
+        dit_config["extra_w_extrapolation_ratio"] = 1.0
+        dit_config["extra_t_extrapolation_ratio"] = 1.0
+        dit_config["rope_enable_fps_modulation"] = False
+
+        return dit_config
+
 if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
     return None
@ -319,6 +319,7 @@ except:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
SUPPORT_FP8_OPS = args.supports_fp8_compute
|
||||||
try:
|
try:
|
||||||
if is_amd():
|
if is_amd():
|
||||||
try:
|
try:
|
||||||
@ -329,9 +330,13 @@ try:
|
|||||||
logging.info("AMD arch: {}".format(arch))
|
logging.info("AMD arch: {}".format(arch))
|
||||||
logging.info("ROCm version: {}".format(rocm_version))
|
logging.info("ROCm version: {}".format(rocm_version))
|
||||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||||
if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
|
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||||
if any((a in arch) for a in ["gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches
|
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx1201 and gfx950
|
||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
|
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
|
||||||
|
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
|
||||||
|
SUPPORT_FP8_OPS = True
|
||||||
|
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -352,7 +357,7 @@ except:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5:
|
if torch_version_numeric >= (2, 5):
|
||||||
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
|
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
|
||||||
except:
|
except:
|
||||||
logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
|
logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
|
||||||
@@ -1075,7 +1080,7 @@ def pytorch_attention_flash_attention():
     global ENABLE_PYTORCH_ATTENTION
     if ENABLE_PYTORCH_ATTENTION:
         #TODO: more reliable way of checking for flash attention?
-        if is_nvidia(): #pytorch flash attention only works on Nvidia
+        if is_nvidia():
             return True
         if is_intel_xpu():
             return True
@@ -1091,7 +1096,7 @@ def force_upcast_attention_dtype():
     upcast = args.force_upcast_attention
 
     macos_version = mac_version()
-    if macos_version is not None and ((14, 5) <= macos_version < (16,)): # black image bug on recent versions of macOS
+    if macos_version is not None and ((14, 5) <= macos_version): # black image bug on recent versions of macOS, I don't think it's ever getting fixed
         upcast = True
 
     if upcast:
@@ -1290,7 +1295,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     return False
 
 def supports_fp8_compute(device=None):
-    if args.supports_fp8_compute:
+    if SUPPORT_FP8_OPS:
         return True
 
     if not is_nvidia():
@@ -1304,11 +1309,11 @@ def supports_fp8_compute(device=None):
     if props.minor < 9:
         return False
 
-    if torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 3):
+    if torch_version_numeric < (2, 3):
         return False
 
     if WINDOWS:
-        if (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 4):
+        if torch_version_numeric < (2, 4):
             return False
 
     return True
@@ -17,23 +17,26 @@
 """
 
 from __future__ import annotations
-from typing import Optional, Callable
-import torch
+import collections
 import copy
 import inspect
 import logging
-import uuid
-import collections
 import math
+import uuid
+
+from typing import Callable, Optional
+
+import torch
 
-import comfy.utils
 import comfy.float
-import comfy.model_management
-import comfy.lora
 import comfy.hooks
+import comfy.lora
+import comfy.model_management
 import comfy.patcher_extension
-from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
+import comfy.utils
 from comfy.comfy_types import UnetWrapperFunction
+from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
 
 
 def string_to_seed(data):
     crc = 0xFFFFFFFF
@@ -77,6 +77,25 @@ class IMG_TO_IMG(X0):
     def calculate_input(self, sigma, noise):
         return noise
 
+
+class COSMOS_RFLOW:
+    def calculate_input(self, sigma, noise):
+        sigma = (sigma / (sigma + 1))
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
+        return noise * (1.0 - sigma)
+
+    def calculate_denoised(self, sigma, model_output, model_input):
+        sigma = (sigma / (sigma + 1))
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        return model_input * (1.0 - sigma) - model_output * sigma
+
+    def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
+        noise = noise * sigma
+        noise += latent_image
+        return noise
+
+    def inverse_noise_scaling(self, sigma, latent):
+        return latent
+
 class ModelSamplingDiscrete(torch.nn.Module):
     def __init__(self, model_config=None, zsnr=None):
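COSMOS_RFLOW keeps the EDM-style convention x = latent + sigma * noise in noise_scaling and lets calculate_input rescale by 1 / (sigma + 1); with t = sigma / (sigma + 1), that is exactly the rectified-flow interpolation (1 - t) * latent + t * noise. A quick numeric check of the identity (illustrative only, not part of the patch):

import torch

sigma = torch.tensor([4.0])
latent, eps = torch.randn(1, 4), torch.randn(1, 4)

x = latent + sigma * eps       # noise_scaling: EDM-style noising
t = sigma / (sigma + 1)
model_in = x * (1.0 - t)       # calculate_input rescale; note 1 - t == 1 / (sigma + 1)

# Same point on the rectified-flow path
assert torch.allclose(model_in, (1 - t) * latent + t * eps, atol=1e-6)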
@@ -350,3 +369,15 @@ class ModelSamplingFlux(torch.nn.Module):
         if percent >= 1.0:
             return 0.0
         return flux_time_shift(self.shift, 1.0, 1.0 - percent)
+
+
+class ModelSamplingCosmosRFlow(ModelSamplingContinuousEDM):
+    def timestep(self, sigma):
+        return sigma / (sigma + 1)
+
+    def sigma(self, timestep):
+        sigma_max = self.sigma_max
+        if timestep >= (sigma_max / (sigma_max + 1)):
+            return sigma_max
+
+        return timestep / (1 - timestep)
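timestep and sigma are mutual inverses on [0, sigma_max]; the clamp in sigma guards the division as the timestep approaches 1. A round-trip sketch (illustrative only; sigma_max=80.0 mirrors the Predict2 sampling settings elsewhere in this commit):

def timestep(s):
    return s / (s + 1)

def sigma(t, sigma_max=80.0):
    if t >= sigma_max / (sigma_max + 1):
        return sigma_max
    return t / (1 - t)

for s in (0.002, 1.0, 10.0, 79.0):
    assert abs(sigma(timestep(s)) - s) < 1e-9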
23  comfy/sd.py

@@ -1081,7 +1081,28 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     return (model_patcher, clip, vae, clipvision)
 
 
-def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
+def load_diffusion_model_state_dict(sd, model_options={}):
+    """
+    Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
+
+    Args:
+        sd (dict): State dictionary containing model weights and configuration
+        model_options (dict, optional): Additional options for model loading. Supports:
+            - dtype: Override model data type
+            - custom_operations: Custom model operations
+            - fp8_optimizations: Enable FP8 optimizations
+
+    Returns:
+        ModelPatcher: A wrapped model instance that handles device management and weight loading.
+        Returns None if the model configuration cannot be detected.
+
+    The function:
+    1. Detects and handles different model formats (regular, diffusers, mmdit)
+    2. Configures model dtype based on parameters and device capabilities
+    3. Handles weight conversion and device placement
+    4. Manages model optimization settings
+    5. Loads weights and returns a device-managed model instance
+    """
     dtype = model_options.get("dtype", None)
 
     #Allow loading unets from checkpoint files
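The new docstring pins down the contract callers can rely on. A hedged usage sketch (the file path and dtype choice are illustrative assumptions, not taken from the patch):

import torch
import comfy.sd
import comfy.utils

sd = comfy.utils.load_torch_file("model.safetensors")  # placeholder path
model_patcher = comfy.sd.load_diffusion_model_state_dict(sd, model_options={"dtype": torch.float16})
if model_patcher is None:
    raise RuntimeError("model configuration could not be detected")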
@@ -462,7 +462,7 @@ class SDTokenizer:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
         self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
-        self.min_length = min_length
+        self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
         self.end_token = None
         self.min_padding = min_padding
@@ -908,6 +908,48 @@ class CosmosI2V(CosmosT2V):
         out = model_base.CosmosVideo(self, image_to_video=True, device=device)
         return out
 
+
+class CosmosT2IPredict2(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "cosmos_predict2",
+        "in_channels": 16,
+    }
+
+    sampling_settings = {
+        "sigma_data": 1.0,
+        "sigma_max": 80.0,
+        "sigma_min": 0.002,
+    }
+
+    unet_extra_config = {}
+    latent_format = latent_formats.Wan21
+
+    memory_usage_factor = 1.0
+
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+    def __init__(self, unet_config):
+        super().__init__(unet_config)
+        self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.CosmosPredict2(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
+
+
+class CosmosI2VPredict2(CosmosT2IPredict2):
+    unet_config = {
+        "image_model": "cosmos_predict2",
+        "in_channels": 17,
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.CosmosPredict2(self, image_to_video=True, device=device)
+        return out
+
+
 class Lumina2(supported_models_base.BASE):
     unet_config = {
         "image_model": "lumina2",
@@ -1139,6 +1181,6 @@ class ACEStep(supported_models_base.BASE):
     def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
 
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
 
 models += [SVD_img2vid]
@@ -1,4 +1,4 @@
-from .base import WeightAdapterBase
+from .base import WeightAdapterBase, WeightAdapterTrainBase
 from .lora import LoRAAdapter
 from .loha import LoHaAdapter
 from .lokr import LoKrAdapter
@@ -15,3 +15,9 @@ adapters: list[type[WeightAdapterBase]] = [
     OFTAdapter,
     BOFTAdapter,
 ]
+
+__all__ = [
+    "WeightAdapterBase",
+    "WeightAdapterTrainBase",
+    "adapters"
+] + [a.__name__ for a in adapters]
@@ -12,12 +12,20 @@ class WeightAdapterBase:
     weights: list[torch.Tensor]
 
     @classmethod
-    def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]:
+    def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]:
         raise NotImplementedError
 
     def to_train(self) -> "WeightAdapterTrainBase":
         raise NotImplementedError
 
+    @classmethod
+    def create_train(cls, weight, *args) -> "WeightAdapterTrainBase":
+        """
+        weight: The original weight tensor to be modified.
+        *args: Additional arguments for configuration, such as rank, alpha etc.
+        """
+        raise NotImplementedError
+
     def calculate_weight(
         self,
         weight,
@@ -33,10 +41,22 @@ class WeightAdapterBase:
 
 
 class WeightAdapterTrainBase(nn.Module):
+    # We follow the scheme of PR #7032
     def __init__(self):
         super().__init__()
 
-    # [TODO] Collaborate with LoRA training PR #7032
+    def __call__(self, w):
+        """
+        w: The original weight tensor to be modified.
+        """
+        raise NotImplementedError
+
+    def passive_memory_usage(self):
+        raise NotImplementedError("passive_memory_usage is not implemented")
+
+    def move_to(self, device):
+        self.to(device)
+        return self.passive_memory_usage()
+
 
 def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
@@ -102,3 +122,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
     padded_tensor[new_slices] = tensor[orig_slices]
 
     return padded_tensor
+
+
+def tucker_weight_from_conv(up, down, mid):
+    up = up.reshape(up.size(0), up.size(1))
+    down = down.reshape(down.size(0), down.size(1))
+    return torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down)
+
+
+def tucker_weight(wa, wb, t):
+    temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
+    return torch.einsum("i j ..., i r -> r j ...", temp, wa)
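tucker_weight_from_conv rebuilds a full convolution kernel from a Tucker-style factorization: mid carries the spatial dimensions at reduced rank, while up and down map the rank back to output and input channels. A shape-level sketch (sizes are illustrative):

import torch

out_ch, in_ch, rank = 32, 16, 4
up = torch.randn(out_ch, rank, 1, 1)   # conv-shaped up weight
down = torch.randn(rank, in_ch, 1, 1)  # conv-shaped down weight
mid = torch.randn(rank, rank, 3, 3)    # spatial core at reduced rank

up2d = up.reshape(up.size(0), up.size(1))
down2d = down.reshape(down.size(0), down.size(1))
w = torch.einsum("m n ..., i m, n j -> i j ...", mid, up2d, down2d)
assert w.shape == (out_ch, in_ch, 3, 3)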
@@ -3,7 +3,56 @@ from typing import Optional
 
 import torch
 import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
+from .base import (
+    WeightAdapterBase,
+    WeightAdapterTrainBase,
+    weight_decompose,
+    pad_tensor_to_shape,
+    tucker_weight_from_conv,
+)
+
+
+class LoraDiff(WeightAdapterTrainBase):
+    def __init__(self, weights):
+        super().__init__()
+        mat1, mat2, alpha, mid, dora_scale, reshape = weights
+        out_dim, rank = mat1.shape[0], mat1.shape[1]
+        rank, in_dim = mat2.shape[0], mat2.shape[1]
+        if mid is not None:
+            convdim = mid.ndim - 2
+            layer = (
+                torch.nn.Conv1d,
+                torch.nn.Conv2d,
+                torch.nn.Conv3d
+            )[convdim]
+        else:
+            layer = torch.nn.Linear
+        self.lora_up = layer(rank, out_dim, bias=False)
+        self.lora_down = layer(in_dim, rank, bias=False)
+        self.lora_up.weight.data.copy_(mat1)
+        self.lora_down.weight.data.copy_(mat2)
+        if mid is not None:
+            self.lora_mid = layer(mid, rank, bias=False)
+            self.lora_mid.weight.data.copy_(mid)
+        else:
+            self.lora_mid = None
+        self.rank = rank
+        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
+
+    def __call__(self, w):
+        org_dtype = w.dtype
+        if self.lora_mid is None:
+            diff = self.lora_up.weight @ self.lora_down.weight
+        else:
+            diff = tucker_weight_from_conv(
+                self.lora_up.weight, self.lora_down.weight, self.lora_mid.weight
+            )
+        scale = self.alpha / self.rank
+        weight = w + scale * diff.reshape(w.shape)
+        return weight.to(org_dtype)
+
+    def passive_memory_usage(self):
+        return sum(param.numel() * param.element_size() for param in self.parameters())
+
+
 class LoRAAdapter(WeightAdapterBase):
@@ -13,6 +62,21 @@ class LoRAAdapter(WeightAdapterBase):
         self.loaded_keys = loaded_keys
         self.weights = weights
 
+    @classmethod
+    def create_train(cls, weight, rank=1, alpha=1.0):
+        out_dim = weight.shape[0]
+        in_dim = weight.shape[1:].numel()
+        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
+        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
+        torch.nn.init.constant_(mat2, 0.0)
+        return LoraDiff(
+            (mat1, mat2, alpha, None, None, None)
+        )
+
+    def to_train(self):
+        return LoraDiff(self.weights)
+
     @classmethod
     def load(
         cls,
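Note that create_train zero-initializes mat2 (the down projection), so a freshly created adapter is an exact identity: lora_up @ lora_down is the zero matrix and __call__ returns w unchanged until the optimizer takes a step. A quick check of that property (assumes a ComfyUI environment where comfy.weight_adapter is importable):

import torch
from comfy.weight_adapter.lora import LoRAAdapter

w = torch.randn(64, 32)
adapter = LoRAAdapter.create_train(w, rank=8, alpha=1.0)
assert torch.equal(adapter(w), w)  # zero delta before any training step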
@@ -346,20 +346,6 @@ class FluxKontextProImageNode(ComfyNodeABC):
             },
         }
 
-    @classmethod
-    def VALIDATE_INPUTS(cls, aspect_ratio: str):
-        try:
-            validate_aspect_ratio(
-                aspect_ratio,
-                minimum_ratio=cls.MINIMUM_RATIO,
-                maximum_ratio=cls.MAXIMUM_RATIO,
-                minimum_ratio_str=cls.MINIMUM_RATIO_STR,
-                maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
-            )
-        except Exception as e:
-            return str(e)
-        return True
-
     RETURN_TYPES = (IO.IMAGE,)
     DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
     FUNCTION = "api_call"
@@ -380,6 +366,13 @@ class FluxKontextProImageNode(ComfyNodeABC):
         unique_id: Union[str, None] = None,
         **kwargs,
     ):
+        aspect_ratio = validate_aspect_ratio(
+            aspect_ratio,
+            minimum_ratio=self.MINIMUM_RATIO,
+            maximum_ratio=self.MAXIMUM_RATIO,
+            minimum_ratio_str=self.MINIMUM_RATIO_STR,
+            maximum_ratio_str=self.MAXIMUM_RATIO_STR,
+        )
         if input_image is None:
             validate_string(prompt, strip_whitespace=False)
         operation = SynchronousOperation(
@@ -395,13 +388,7 @@ class FluxKontextProImageNode(ComfyNodeABC):
                 guidance=round(guidance, 1),
                 steps=steps,
                 seed=seed,
-                aspect_ratio=validate_aspect_ratio(
-                    aspect_ratio,
-                    minimum_ratio=self.MINIMUM_RATIO,
-                    maximum_ratio=self.MAXIMUM_RATIO,
-                    minimum_ratio_str=self.MINIMUM_RATIO_STR,
-                    maximum_ratio_str=self.MAXIMUM_RATIO_STR,
-                ),
+                aspect_ratio=aspect_ratio,
                 input_image=(
                     input_image
                     if input_image is None
@@ -324,7 +324,7 @@ class IdeogramV1(ComfyNodeABC):
 
     RETURN_TYPES = (IO.IMAGE,)
     FUNCTION = "api_call"
-    CATEGORY = "api node/image/Ideogram/v1"
+    CATEGORY = "api node/image/Ideogram"
     DESCRIPTION = cleandoc(__doc__ or "")
     API_NODE = True
 
@@ -483,7 +483,7 @@ class IdeogramV2(ComfyNodeABC):
 
     RETURN_TYPES = (IO.IMAGE,)
     FUNCTION = "api_call"
-    CATEGORY = "api node/image/Ideogram/v2"
+    CATEGORY = "api node/image/Ideogram"
     DESCRIPTION = cleandoc(__doc__ or "")
     API_NODE = True
 
@@ -649,7 +649,7 @@ class IdeogramV3(ComfyNodeABC):
 
     RETURN_TYPES = (IO.IMAGE,)
     FUNCTION = "api_call"
-    CATEGORY = "api node/image/Ideogram/v3"
+    CATEGORY = "api node/image/Ideogram"
     DESCRIPTION = cleandoc(__doc__ or "")
     API_NODE = True
97  comfy_config/config_parser.py  (new file)

@@ -0,0 +1,97 @@
+import os
+from pathlib import Path
+from typing import Optional
+
+from pydantic_settings import PydanticBaseSettingsSource, TomlConfigSettingsSource
+
+from comfy_config.types import (
+    ComfyConfig,
+    ProjectConfig,
+    PyProjectConfig,
+    PyProjectSettings
+)
+
+"""
+Extract configuration from a custom node directory's pyproject.toml file or a Python file.
+
+This function reads and parses the pyproject.toml file in the specified directory
+to extract project and ComfyUI-specific configuration information. If no
+pyproject.toml file is found, it creates a minimal configuration using the
+folder name as the project name. If a Python file is provided, it uses the
+file name (without extension) as the project name.
+
+Args:
+    path (str): Path to the directory containing the pyproject.toml file, or
+                path to a .py file. If pyproject.toml doesn't exist in a directory,
+                the folder name will be used as the default project name. If a .py
+                file is provided, the filename (without .py extension) will be used
+                as the project name.
+
+Returns:
+    Optional[PyProjectConfig]: A PyProjectConfig object containing:
+        - project: Basic project information (name, version, dependencies, etc.)
+        - tool_comfy: ComfyUI-specific configuration (publisher_id, models, etc.)
+        Returns None if configuration extraction fails or if the provided file
+        is not a Python file.
+
+Notes:
+    - If pyproject.toml is missing in a directory, creates a default config with folder name
+    - If a .py file is provided, creates a default config with filename (without extension)
+    - Returns None for non-Python files
+
+Example:
+    >>> from comfy_config import config_parser
+    >>> # For directory
+    >>> custom_node_dir = os.path.dirname(os.path.realpath(__file__))
+    >>> project_config = config_parser.extract_node_configuration(custom_node_dir)
+    >>> print(project_config.project.name) # "my_custom_node" or name from pyproject.toml
+    >>>
+    >>> # For single-file Python node file
+    >>> py_file_path = os.path.realpath(__file__) # "/path/to/my_node.py"
+    >>> project_config = config_parser.extract_node_configuration(py_file_path)
+    >>> print(project_config.project.name) # "my_node"
+"""
+def extract_node_configuration(path) -> Optional[PyProjectConfig]:
+    if os.path.isfile(path):
+        file_path = Path(path)
+
+        if file_path.suffix.lower() != '.py':
+            return None
+
+        project_name = file_path.stem
+        project = ProjectConfig(name=project_name)
+        comfy = ComfyConfig()
+        return PyProjectConfig(project=project, tool_comfy=comfy)
+
+    folder_name = os.path.basename(path)
+    toml_path = Path(path) / "pyproject.toml"
+
+    if not toml_path.exists():
+        project = ProjectConfig(name=folder_name)
+        comfy = ComfyConfig()
+        return PyProjectConfig(project=project, tool_comfy=comfy)
+
+    raw_settings = load_pyproject_settings(toml_path)
+
+    project_data = raw_settings.project
+
+    tool_data = raw_settings.tool
+    comfy_data = tool_data.get("comfy", {}) if tool_data else {}
+
+    return PyProjectConfig(project=project_data, tool_comfy=comfy_data)
+
+
+def load_pyproject_settings(toml_path: Path) -> PyProjectSettings:
+    class PyProjectLoader(PyProjectSettings):
+        @classmethod
+        def settings_customise_sources(
+            cls,
+            settings_cls,
+            init_settings: PydanticBaseSettingsSource,
+            env_settings: PydanticBaseSettingsSource,
+            dotenv_settings: PydanticBaseSettingsSource,
+            file_secret_settings: PydanticBaseSettingsSource,
+        ):
+            return (TomlConfigSettingsSource(settings_cls, toml_path),)
+
+    return PyProjectLoader()
93  comfy_config/types.py  (new file)

@@ -0,0 +1,93 @@
+from pydantic import BaseModel, Field, field_validator
+from pydantic_settings import BaseSettings, SettingsConfigDict
+from typing import List, Optional
+
+# IMPORTANT: The type definitions specified in pyproject.toml for custom nodes
+# must remain synchronized with the corresponding files in the https://github.com/Comfy-Org/comfy-cli/blob/main/comfy_cli/registry/types.py.
+# Any changes to one must be reflected in the other to maintain consistency.
+
+class NodeVersion(BaseModel):
+    changelog: str
+    dependencies: List[str]
+    deprecated: bool
+    id: str
+    version: str
+    download_url: str
+
+
+class Node(BaseModel):
+    id: str
+    name: str
+    description: str
+    author: Optional[str] = None
+    license: Optional[str] = None
+    icon: Optional[str] = None
+    repository: Optional[str] = None
+    tags: List[str] = Field(default_factory=list)
+    latest_version: Optional[NodeVersion] = None
+
+
+class PublishNodeVersionResponse(BaseModel):
+    node_version: NodeVersion
+    signedUrl: str
+
+
+class URLs(BaseModel):
+    homepage: str = Field(default="", alias="Homepage")
+    documentation: str = Field(default="", alias="Documentation")
+    repository: str = Field(default="", alias="Repository")
+    issues: str = Field(default="", alias="Issues")
+
+
+class Model(BaseModel):
+    location: str
+    model_url: str
+
+
+class ComfyConfig(BaseModel):
+    publisher_id: str = Field(default="", alias="PublisherId")
+    display_name: str = Field(default="", alias="DisplayName")
+    icon: str = Field(default="", alias="Icon")
+    models: List[Model] = Field(default_factory=list, alias="Models")
+    includes: List[str] = Field(default_factory=list)
+    web: Optional[str] = None
+
+
+class License(BaseModel):
+    file: str = ""
+    text: str = ""
+
+
+class ProjectConfig(BaseModel):
+    name: str = ""
+    description: str = ""
+    version: str = "1.0.0"
+    requires_python: str = Field(default=">= 3.9", alias="requires-python")
+    dependencies: List[str] = Field(default_factory=list)
+    license: License = Field(default_factory=License)
+    urls: URLs = Field(default_factory=URLs)
+
+    @field_validator('license', mode='before')
+    @classmethod
+    def validate_license(cls, v):
+        if isinstance(v, str):
+            return License(text=v)
+        elif isinstance(v, dict):
+            return License(**v)
+        elif isinstance(v, License):
+            return v
+        else:
+            return License()
+
+
+class PyProjectConfig(BaseModel):
+    project: ProjectConfig = Field(default_factory=ProjectConfig)
+    tool_comfy: ComfyConfig = Field(default_factory=ComfyConfig)
+
+
+class PyProjectSettings(BaseSettings):
+    project: dict = Field(default_factory=dict)
+
+    tool: dict = Field(default_factory=dict)
+
+    model_config = SettingsConfigDict(extra='allow')
@@ -2,6 +2,7 @@ import nodes
 import torch
 import comfy.model_management
 import comfy.utils
+import comfy.latent_formats
 
 
 class EmptyCosmosLatentVideo:
@@ -75,8 +76,53 @@ class CosmosImageToVideoLatent:
         out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
         return (out_latent,)
 
+
+class CosmosPredict2ImageToVideoLatent:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"vae": ("VAE", ),
+                             "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+                             "length": ("INT", {"default": 93, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
+                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                             },
+                "optional": {"start_image": ("IMAGE", ),
+                             "end_image": ("IMAGE", ),
+                             }}
+
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/inpaint"
+
+    def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None):
+        latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        if start_image is None and end_image is None:
+            out_latent = {}
+            out_latent["samples"] = latent
+            return (out_latent,)
+
+        mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
+
+        if start_image is not None:
+            latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1)
+            latent[:, :, :latent_temp.shape[-3]] = latent_temp
+            mask[:, :, :latent_temp.shape[-3]] *= 0.0
+
+        if end_image is not None:
+            latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0)
+            latent[:, :, -latent_temp.shape[-3]:] = latent_temp
+            mask[:, :, -latent_temp.shape[-3]:] *= 0.0
+
+        out_latent = {}
+        latent_format = comfy.latent_formats.Wan21()
+        latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
+        out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
+        out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
+        return (out_latent,)
+
 NODE_CLASS_MAPPINGS = {
     "EmptyCosmosLatentVideo": EmptyCosmosLatentVideo,
     "CosmosImageToVideoLatent": CosmosImageToVideoLatent,
+    "CosmosPredict2ImageToVideoLatent": CosmosPredict2ImageToVideoLatent,
 }
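The temporal size of the latent here is ((length - 1) // 4) + 1, matching the 4x temporal compression of the Wan21 latent space with the first frame kept whole, while spatial dimensions divide by 8. For the node's defaults this works out as follows (an illustrative sketch of the arithmetic in encode, not part of the patch):

def cosmos_predict2_latent_shape(width=848, height=480, length=93, channels=16):
    t = ((length - 1) // 4) + 1  # 4x temporal compression, first frame kept
    return (1, channels, t, height // 8, width // 8)

assert cosmos_predict2_latent_shape() == (1, 16, 24, 60, 106)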
@@ -16,7 +16,8 @@ from inspect import cleandoc
 import torch
 import comfy.utils
 
-from comfy.comfy_types import FileLocator
+from comfy.comfy_types import FileLocator, IO
+from server import PromptServer
 
 MAX_RESOLUTION = nodes.MAX_RESOLUTION
 
@@ -491,6 +492,37 @@ class SaveSVGNode:
             counter += 1
         return { "ui": { "images": results } }
 
+
+class GetImageSize:
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": (IO.IMAGE,),
+            },
+            "hidden": {
+                "unique_id": "UNIQUE_ID",
+            }
+        }
+
+    RETURN_TYPES = (IO.INT, IO.INT, IO.INT)
+    RETURN_NAMES = ("width", "height", "batch_size")
+    FUNCTION = "get_size"
+
+    CATEGORY = "image"
+    DESCRIPTION = """Returns width and height of the image, and passes it through unchanged."""
+
+    def get_size(self, image, unique_id=None) -> tuple[int, int]:
+        height = image.shape[1]
+        width = image.shape[2]
+        batch_size = image.shape[0]
+
+        # Send progress text to display size on the node
+        if unique_id:
+            PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id)
+
+        return width, height, batch_size
+
 NODE_CLASS_MAPPINGS = {
     "ImageCrop": ImageCrop,
     "RepeatImageBatch": RepeatImageBatch,
@@ -500,4 +532,5 @@ NODE_CLASS_MAPPINGS = {
     "SaveAnimatedPNG": SaveAnimatedPNG,
     "SaveSVGNode": SaveSVGNode,
     "ImageStitch": ImageStitch,
+    "GetImageSize": GetImageSize,
 }
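The index order in get_size reflects ComfyUI's IMAGE tensor layout of [batch, height, width, channels], so shape[1] is height and shape[2] is width. A minimal sketch (illustrative only):

import torch

image = torch.zeros(2, 480, 640, 3)  # [batch, height, width, channels]
batch_size, height, width = image.shape[0], image.shape[1], image.shape[2]
assert (width, height, batch_size) == (640, 480, 2)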
@@ -189,7 +189,7 @@ class ModelSamplingContinuousEDM:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "model": ("MODEL",),
-                              "sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps"],),
+                              "sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps", "cosmos_rflow"],),
                               "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
                               "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
                               }}
@@ -202,6 +202,7 @@ class ModelSamplingContinuousEDM:
     def patch(self, model, sampling, sigma_max, sigma_min):
         m = model.clone()
 
+        sampling_base = comfy.model_sampling.ModelSamplingContinuousEDM
         latent_format = None
         sigma_data = 1.0
         if sampling == "eps":
@@ -215,8 +216,11 @@ class ModelSamplingContinuousEDM:
             sampling_type = comfy.model_sampling.EDM
             sigma_data = 0.5
             latent_format = comfy.latent_formats.SDXL_Playground_2_5()
+        elif sampling == "cosmos_rflow":
+            sampling_type = comfy.model_sampling.COSMOS_RFLOW
+            sampling_base = comfy.model_sampling.ModelSamplingCosmosRFlow
 
-        class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
+        class ModelSamplingAdvanced(sampling_base, sampling_type):
             pass
 
         model_sampling = ModelSamplingAdvanced(model.model.model_config)
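Hoisting the base class into sampling_base lets the cosmos_rflow mode swap in ModelSamplingCosmosRFlow while every other mode keeps ModelSamplingContinuousEDM; ModelSamplingAdvanced is then composed from a variable base plus a sampling-type mixin. The pattern in isolation (a sketch with stand-in classes, not the ComfyUI types):

class BaseA:
    def name(self):
        return "A:" + self.kind()

class BaseB:
    def name(self):
        return "B:" + self.kind()

class MixinX:
    def kind(self):
        return "x"

def make(selector):
    base = BaseB if selector == "b" else BaseA
    class Combined(base, MixinX):  # same shape as ModelSamplingAdvanced(sampling_base, sampling_type)
        pass
    return Combined()

assert make("a").name() == "A:x" and make("b").name() == "B:x"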
709
comfy_extras/nodes_train.py
Normal file
709
comfy_extras/nodes_train.py
Normal file
@ -0,0 +1,709 @@
|
|||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import safetensors
|
||||||
|
import torch
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
from PIL.PngImagePlugin import PngInfo
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
import comfy.samplers
|
||||||
|
import comfy.sd
|
||||||
|
import comfy.utils
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy_extras.nodes_custom_sampler
|
||||||
|
import folder_paths
|
||||||
|
import node_helpers
|
||||||
|
from comfy.cli_args import args
|
||||||
|
from comfy.comfy_types.node_typing import IO
|
||||||
|
from comfy.weight_adapter import adapters
|
||||||
|
|
||||||
|
|
||||||
|
class TrainSampler(comfy.samplers.Sampler):
|
||||||
|
|
||||||
|
def __init__(self, loss_fn, optimizer, loss_callback=None):
|
||||||
|
self.loss_fn = loss_fn
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self.loss_callback = loss_callback
|
||||||
|
|
||||||
|
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False)
|
||||||
|
latent = model_wrap.inner_model.model_sampling.noise_scaling(
|
||||||
|
torch.zeros_like(sigmas),
|
||||||
|
torch.zeros_like(noise, requires_grad=True),
|
||||||
|
latent_image,
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure model is in training mode and computing gradients
|
||||||
|
# x0 pred
|
||||||
|
denoised = model_wrap(noise, sigmas, **extra_args)
|
||||||
|
try:
|
||||||
|
loss = self.loss_fn(denoised, latent.clone())
|
||||||
|
except RuntimeError as e:
|
||||||
|
if "does not require grad and does not have a grad_fn" in str(e):
|
||||||
|
logging.info("WARNING: This is likely due to the model is loaded in inference mode.")
|
||||||
|
loss.backward()
|
||||||
|
if self.loss_callback:
|
||||||
|
self.loss_callback(loss.item())
|
||||||
|
|
||||||
|
self.optimizer.step()
|
||||||
|
# torch.cuda.memory._dump_snapshot("trainn.pickle")
|
||||||
|
# torch.cuda.memory._record_memory_history(enabled=None)
|
||||||
|
return torch.zeros_like(latent_image)
|
||||||
|
|
||||||
|
|
||||||
|
class BiasDiff(torch.nn.Module):
|
||||||
|
def __init__(self, bias):
|
||||||
|
super().__init__()
|
||||||
|
self.bias = bias
|
||||||
|
|
||||||
|
def __call__(self, b):
|
||||||
|
org_dtype = b.dtype
|
||||||
|
return (b.to(self.bias) + self.bias).to(org_dtype)
|
||||||
|
|
||||||
|
def passive_memory_usage(self):
|
||||||
|
return self.bias.nelement() * self.bias.element_size()
|
||||||
|
|
||||||
|
def move_to(self, device):
|
||||||
|
self.to(device=device)
|
||||||
|
return self.passive_memory_usage()
|
||||||
|
|
||||||
|
|
||||||
|
def load_and_process_images(image_files, input_dir, resize_method="None"):
|
||||||
|
"""Utility function to load and process a list of images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_files: List of image filenames
|
||||||
|
input_dir: Base directory containing the images
|
||||||
|
resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Batch of processed images
|
||||||
|
"""
|
||||||
|
if not image_files:
|
||||||
|
raise ValueError("No valid images found in input")
|
||||||
|
|
||||||
|
output_images = []
|
||||||
|
w, h = None, None
|
||||||
|
|
||||||
|
for file in image_files:
|
||||||
|
image_path = os.path.join(input_dir, file)
|
||||||
|
img = node_helpers.pillow(Image.open, image_path)
|
||||||
|
|
||||||
|
if img.mode == "I":
|
||||||
|
img = img.point(lambda i: i * (1 / 255))
|
||||||
|
img = img.convert("RGB")
|
||||||
|
|
||||||
|
if w is None and h is None:
|
||||||
|
w, h = img.size[0], img.size[1]
|
||||||
|
|
||||||
|
# Resize image to first image
|
||||||
|
if img.size[0] != w or img.size[1] != h:
|
||||||
|
if resize_method == "Stretch":
|
||||||
|
img = img.resize((w, h), Image.Resampling.LANCZOS)
|
||||||
|
elif resize_method == "Crop":
|
||||||
|
img = img.crop((0, 0, w, h))
|
||||||
|
elif resize_method == "Pad":
|
||||||
|
img = img.resize((w, h), Image.Resampling.LANCZOS)
|
||||||
|
elif resize_method == "None":
|
||||||
|
raise ValueError(
|
||||||
|
"Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
|
||||||
|
)
|
||||||
|
|
||||||
|
img_array = np.array(img).astype(np.float32) / 255.0
|
||||||
|
img_tensor = torch.from_numpy(img_array)[None,]
|
||||||
|
output_images.append(img_tensor)
|
||||||
|
|
||||||
|
return torch.cat(output_images, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
class LoadImageSetNode:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"images": (
|
||||||
|
[
|
||||||
|
f
|
||||||
|
for f in os.listdir(folder_paths.get_input_directory())
|
||||||
|
if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"))
|
||||||
|
],
|
||||||
|
{"image_upload": True, "allow_batch": True},
|
||||||
|
)
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"resize_method": (
|
||||||
|
["None", "Stretch", "Crop", "Pad"],
|
||||||
|
{"default": "None"},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
INPUT_IS_LIST = True
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "load_images"
|
||||||
|
CATEGORY = "loaders"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
DESCRIPTION = "Loads a batch of images from a directory for training."
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def VALIDATE_INPUTS(s, images, resize_method):
|
||||||
|
filenames = images[0] if isinstance(images[0], list) else images
|
||||||
|
|
||||||
|
for image in filenames:
|
||||||
|
if not folder_paths.exists_annotated_filepath(image):
|
||||||
|
return "Invalid image file: {}".format(image)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def load_images(self, input_files, resize_method):
|
||||||
|
input_dir = folder_paths.get_input_directory()
|
||||||
|
valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"]
|
||||||
|
image_files = [
|
||||||
|
f
|
||||||
|
for f in input_files
|
||||||
|
if any(f.lower().endswith(ext) for ext in valid_extensions)
|
||||||
|
]
|
||||||
|
output_tensor = load_and_process_images(image_files, input_dir, resize_method)
|
||||||
|
return (output_tensor,)
|
||||||
|
|
||||||
|
|
||||||
|
class LoadImageSetFromFolderNode:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."})
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"resize_method": (
|
||||||
|
["None", "Stretch", "Crop", "Pad"],
|
||||||
|
{"default": "None"},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "load_images"
|
||||||
|
CATEGORY = "loaders"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
DESCRIPTION = "Loads a batch of images from a directory for training."
|
||||||
|
|
||||||
|
def load_images(self, folder, resize_method):
|
||||||
|
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
|
||||||
|
valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
|
||||||
|
image_files = [
|
||||||
|
f
|
||||||
|
for f in os.listdir(sub_input_dir)
|
||||||
|
if any(f.lower().endswith(ext) for ext in valid_extensions)
|
||||||
|
]
|
||||||
|
output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method)
|
||||||
|
return (output_tensor,)
|
||||||
|
|
||||||
|
|
||||||
|
def draw_loss_graph(loss_map, steps):
|
||||||
|
width, height = 500, 300
|
||||||
|
img = Image.new("RGB", (width, height), "white")
|
||||||
|
draw = ImageDraw.Draw(img)
|
||||||
|
|
||||||
|
min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
|
||||||
|
scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_map.values()]
|
||||||
|
|
||||||
|
prev_point = (0, height - int(scaled_loss[0] * height))
|
||||||
|
for i, l in enumerate(scaled_loss[1:], start=1):
|
||||||
|
x = int(i / (steps - 1) * width)
|
||||||
|
y = height - int(l * height)
|
||||||
|
draw.line([prev_point, (x, y)], fill="blue", width=2)
|
||||||
|
prev_point = (x, y)
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None):
|
||||||
|
if result is None:
|
||||||
|
result = []
|
||||||
|
elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
|
||||||
|
result.append(model)
|
||||||
|
logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
|
||||||
|
return result
|
||||||
|
name = name or "root"
|
||||||
|
for next_name, child in model.named_children():
|
||||||
|
find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def patch(m):
|
||||||
|
if not hasattr(m, "forward"):
|
||||||
|
return
|
||||||
|
org_forward = m.forward
|
||||||
|
def fwd(args, kwargs):
|
||||||
|
return org_forward(*args, **kwargs)
|
||||||
|
def checkpointing_fwd(*args, **kwargs):
|
||||||
|
return torch.utils.checkpoint.checkpoint(
|
||||||
|
fwd, args, kwargs, use_reentrant=False
|
||||||
|
)
|
||||||
|
m.org_forward = org_forward
|
||||||
|
m.forward = checkpointing_fwd
|
||||||
|
|
||||||
|
|
||||||
|
def unpatch(m):
|
||||||
|
if hasattr(m, "org_forward"):
|
||||||
|
m.forward = m.org_forward
|
||||||
|
del m.org_forward
|
||||||
|
|
||||||
|
|
||||||
|
class TrainLoraNode:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}),
|
||||||
|
"latents": (
|
||||||
|
"LATENT",
|
||||||
|
{
|
||||||
|
"tooltip": "The Latents to use for training, serve as dataset/input of the model."
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"positive": (
|
||||||
|
IO.CONDITIONING,
|
||||||
|
{"tooltip": "The positive conditioning to use for training."},
|
||||||
|
),
|
||||||
|
"batch_size": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 1,
|
||||||
|
"min": 1,
|
||||||
|
"max": 10000,
|
||||||
|
"step": 1,
|
||||||
|
"tooltip": "The batch size to use for training.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"steps": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 16,
|
||||||
|
"min": 1,
|
||||||
|
"max": 100000,
|
||||||
|
"tooltip": "The number of steps to train the LoRA for.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"learning_rate": (
|
||||||
|
IO.FLOAT,
|
||||||
|
{
|
||||||
|
"default": 0.0005,
|
||||||
|
"min": 0.0000001,
|
||||||
|
"max": 1.0,
|
||||||
|
"step": 0.000001,
|
||||||
|
"tooltip": "The learning rate to use for training.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"rank": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 8,
|
||||||
|
"min": 1,
|
||||||
|
"max": 128,
|
||||||
|
"tooltip": "The rank of the LoRA layers.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"optimizer": (
|
||||||
|
["AdamW", "Adam", "SGD", "RMSprop"],
|
||||||
|
{
|
||||||
|
"default": "AdamW",
|
||||||
|
"tooltip": "The optimizer to use for training.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"loss_function": (
|
||||||
|
["MSE", "L1", "Huber", "SmoothL1"],
|
||||||
|
{
|
||||||
|
"default": "MSE",
|
||||||
|
"tooltip": "The loss function to use for training.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 0xFFFFFFFFFFFFFFFF,
|
||||||
|
"tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"training_dtype": (
|
||||||
|
["bf16", "fp32"],
|
||||||
|
{"default": "bf16", "tooltip": "The dtype to use for training."},
|
||||||
|
),
|
||||||
|
"lora_dtype": (
|
||||||
|
["bf16", "fp32"],
|
||||||
|
{"default": "bf16", "tooltip": "The dtype to use for lora."},
|
||||||
|
),
|
||||||
|
"existing_lora": (
|
||||||
|
folder_paths.get_filename_list("loras") + ["[None]"],
|
||||||
|
{
|
||||||
|
"default": "[None]",
|
||||||
|
"tooltip": "The existing LoRA to append to. Set to None for new LoRA.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT)
|
||||||
|
RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps")
|
||||||
|
FUNCTION = "train"
|
||||||
|
CATEGORY = "training"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
def train(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
latents,
|
||||||
|
positive,
|
||||||
|
batch_size,
|
||||||
|
steps,
|
||||||
|
learning_rate,
|
||||||
|
rank,
|
||||||
|
optimizer,
|
||||||
|
loss_function,
|
||||||
|
seed,
|
||||||
|
training_dtype,
|
||||||
|
lora_dtype,
|
||||||
|
existing_lora,
|
||||||
|
):
|
||||||
|
mp = model.clone()
|
||||||
|
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||||
|
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||||
|
mp.set_model_compute_dtype(dtype)
|
||||||
|
|
||||||
|
latents = latents["samples"].to(dtype)
|
||||||
|
num_images = latents.shape[0]
|
||||||
|
|
||||||
|
with torch.inference_mode(False):
|
||||||
|
lora_sd = {}
|
||||||
|
generator = torch.Generator()
|
||||||
|
generator.manual_seed(seed)
|
||||||
|
|
||||||
|
# Load existing LoRA weights if provided
|
||||||
|
existing_weights = {}
|
||||||
|
existing_steps = 0
|
||||||
|
if existing_lora != "[None]":
|
||||||
|
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
|
||||||
|
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
|
||||||
|
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
|
||||||
|
if lora_path:
|
||||||
|
existing_weights = comfy.utils.load_torch_file(lora_path)
|
||||||
|
|
||||||
|
all_weight_adapters = []
|
||||||
|
for n, m in mp.model.named_modules():
|
||||||
|
if hasattr(m, "weight_function"):
|
||||||
|
if m.weight is not None:
|
||||||
|
key = "{}.weight".format(n)
|
||||||
|
shape = m.weight.shape
|
||||||
|
if len(shape) >= 2:
|
||||||
|
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
|
||||||
|
dora_scale = existing_weights.get(
|
||||||
|
f"{key}.dora_scale", None
|
||||||
|
)
|
||||||
|
for adapter_cls in adapters:
|
||||||
|
existing_adapter = adapter_cls.load(
|
||||||
|
n, existing_weights, alpha, dora_scale
|
||||||
|
)
|
||||||
|
if existing_adapter is not None:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# If no existing adapter found, use LoRA
|
||||||
|
# We will add algo option in the future
|
||||||
|
existing_adapter = None
|
||||||
|
adapter_cls = adapters[0]
|
||||||
|
|
||||||
|
if existing_adapter is not None:
|
||||||
|
train_adapter = existing_adapter.to_train().to(lora_dtype)
|
||||||
|
else:
|
||||||
|
# Use LoRA with alpha=1.0 by default
|
||||||
|
train_adapter = adapter_cls.create_train(
|
||||||
|
m.weight, rank=rank, alpha=1.0
|
||||||
|
).to(lora_dtype)
|
||||||
|
for name, parameter in train_adapter.named_parameters():
|
||||||
|
lora_sd[f"{n}.{name}"] = parameter
|
||||||
|
|
||||||
|
mp.add_weight_wrapper(key, train_adapter)
|
||||||
|
all_weight_adapters.append(train_adapter)
|
||||||
|
else:
|
||||||
|
diff = torch.nn.Parameter(
|
||||||
|
torch.zeros(
|
||||||
|
m.weight.shape, dtype=lora_dtype, requires_grad=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
diff_module = BiasDiff(diff)
|
||||||
|
mp.add_weight_wrapper(key, BiasDiff(diff))
|
||||||
|
all_weight_adapters.append(diff_module)
|
||||||
|
lora_sd["{}.diff".format(n)] = diff
|
||||||
|
if hasattr(m, "bias") and m.bias is not None:
|
||||||
|
key = "{}.bias".format(n)
|
||||||
|
bias = torch.nn.Parameter(
|
||||||
|
torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True)
|
||||||
|
)
|
||||||
|
bias_module = BiasDiff(bias)
|
||||||
|
lora_sd["{}.diff_b".format(n)] = bias
|
||||||
|
mp.add_weight_wrapper(key, BiasDiff(bias))
|
||||||
|
all_weight_adapters.append(bias_module)
|
||||||
|
|
||||||
|
if optimizer == "Adam":
|
||||||
|
optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
|
||||||
|
elif optimizer == "AdamW":
|
||||||
|
optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
|
||||||
|
elif optimizer == "SGD":
|
||||||
|
optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
|
||||||
|
elif optimizer == "RMSprop":
|
||||||
|
optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
|
||||||
|
|
||||||
|
# Setup loss function based on selection
|
||||||
|
if loss_function == "MSE":
|
||||||
|
criterion = torch.nn.MSELoss()
|
||||||
|
elif loss_function == "L1":
|
||||||
|
criterion = torch.nn.L1Loss()
|
||||||
|
elif loss_function == "Huber":
|
||||||
|
criterion = torch.nn.HuberLoss()
|
||||||
|
elif loss_function == "SmoothL1":
|
||||||
|
criterion = torch.nn.SmoothL1Loss()
|
||||||
|
|
||||||
|
    # setup models
    for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
        patch(m)
    comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)

    # Setup sampler and guider like in test script
    loss_map = {"loss": []}
    def loss_callback(loss):
        loss_map["loss"].append(loss)
        pbar.set_postfix({"loss": f"{loss:.4f}"})
    train_sampler = TrainSampler(
        criterion, optimizer, loss_callback=loss_callback
    )
    guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
    guider.set_conds(positive)  # Set conditioning from input
    ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced()

    # yoland: this currently resizes to the first image in the dataset

    # Training loop
    torch.cuda.empty_cache()
    try:
        for step in (pbar := tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
            # Generate random sigma
            sigma = mp.model.model_sampling.percent_to_sigma(
                torch.rand((1,)).item()
            )
            sigma = torch.tensor([sigma])

            noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)

            indices = torch.randperm(num_images)[:batch_size]
            ss.sample(
                noise, guider, train_sampler, sigma, {"samples": latents[indices].clone()}
            )
    finally:
        for m in mp.model.modules():
            unpatch(m)
        del ss, train_sampler, optimizer
        torch.cuda.empty_cache()

    for adapter in all_weight_adapters:
        adapter.requires_grad_(False)

    for param in lora_sd:
        lora_sd[param] = lora_sd[param].to(lora_dtype)

    return (mp, lora_sd, loss_map, steps + existing_steps)


class LoraModelLoader:
    def __init__(self):
        self.loaded_lora = None

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
                "lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to apply to the diffusion model."}),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    OUTPUT_TOOLTIPS = ("The modified diffusion model.",)
    FUNCTION = "load_lora_model"

    CATEGORY = "loaders"
    DESCRIPTION = "Load trained LoRA weights from the Train LoRA node."
    EXPERIMENTAL = True

    def load_lora_model(self, model, lora, strength_model):
        if strength_model == 0:
            return (model, )

        model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
        return (model_lora, )


class SaveLoRA:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "lora": (
                    IO.LORA_MODEL,
                    {
                        "tooltip": "The LoRA model to save. Do not use the model with LoRA layers."
                    },
                ),
                "prefix": (
                    "STRING",
                    {
                        "default": "loras/ComfyUI_trained_lora",
                        "tooltip": "The prefix to use for the saved LoRA file.",
                    },
                ),
            },
            "optional": {
                "steps": (
                    IO.INT,
                    {
                        "forceInput": True,
                        "tooltip": "Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
                    },
                ),
            },
        }

    RETURN_TYPES = ()
    FUNCTION = "save"
    CATEGORY = "loaders"
    EXPERIMENTAL = True
    OUTPUT_NODE = True

    def save(self, lora, prefix, steps=None):
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(prefix, self.output_dir)
        if steps is None:
            output_checkpoint = f"{filename}_{counter:05}_.safetensors"
        else:
            output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
        safetensors.torch.save_file(lora, output_checkpoint)
        return {}

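# Note: the file written by SaveLoRA.save() is a plain safetensors state dict,
# so it can be reloaded for inspection outside ComfyUI. A minimal sketch (the
# filename below is a placeholder for whatever SaveLoRA produced):
#
#   import safetensors.torch
#   lora_sd = safetensors.torch.load_file("output/loras/ComfyUI_trained_lora_00001_.safetensors")
#   for name, tensor in lora_sd.items():
#       print(name, tuple(tensor.shape), tensor.dtype)
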
class LossGraphNode:
    def __init__(self):
        self.output_dir = folder_paths.get_temp_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "loss": (IO.LOSS_MAP, {"default": {}}),
                "filename_prefix": (IO.STRING, {"default": "loss_graph"}),
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
        }

    RETURN_TYPES = ()
    FUNCTION = "plot_loss"
    OUTPUT_NODE = True
    CATEGORY = "training"
    EXPERIMENTAL = True
    DESCRIPTION = "Plots the loss graph and saves it to the output directory."

    def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None):
        loss_values = loss["loss"]
        width, height = 800, 480
        margin = 40

        img = Image.new(
            "RGB", (width + margin, height + margin), "white"
        )  # Extend canvas
        draw = ImageDraw.Draw(img)

        min_loss, max_loss = min(loss_values), max(loss_values)
        loss_range = (max_loss - min_loss) or 1.0  # guard against division by zero for a constant series
        scaled_loss = [(l - min_loss) / loss_range for l in loss_values]

        steps = len(loss_values)

        prev_point = (margin, height - int(scaled_loss[0] * height))
        for i, l in enumerate(scaled_loss[1:], start=1):
            x = margin + int(i / steps * width)  # Scale X properly
            y = height - int(l * height)
            draw.line([prev_point, (x, y)], fill="blue", width=2)
            prev_point = (x, y)

        draw.line([(margin, 0), (margin, height)], fill="black", width=2)  # Y-axis
        draw.line(
            [(margin, height), (width + margin, height)], fill="black", width=2
        )  # X-axis

        font = None
        try:
            font = ImageFont.truetype("arial.ttf", 12)
        except IOError:
            font = ImageFont.load_default()

        # Add axis labels
        draw.text((5, height // 2), "Loss", font=font, fill="black")
        draw.text((width // 2, height + 10), "Steps", font=font, fill="black")

        # Add min/max loss values
        draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
        draw.text(
            (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
        )

        metadata = None
        if not args.disable_metadata:
            metadata = PngInfo()
            if prompt is not None:
                metadata.add_text("prompt", json.dumps(prompt))
            if extra_pnginfo is not None:
                for x in extra_pnginfo:
                    metadata.add_text(x, json.dumps(extra_pnginfo[x]))

        date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        img.save(
            os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"),
            pnginfo=metadata,
        )
        return {
            "ui": {
                "images": [
                    {
                        "filename": f"{filename_prefix}_{date}.png",
                        "subfolder": "",
                        "type": "temp",
                    }
                ]
            }
        }


NODE_CLASS_MAPPINGS = {
    "TrainLoraNode": TrainLoraNode,
    "SaveLoRANode": SaveLoRA,
    "LoraModelLoader": LoraModelLoader,
    "LoadImageSetFromFolderNode": LoadImageSetFromFolderNode,
    "LossGraphNode": LossGraphNode,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "TrainLoraNode": "Train LoRA",
    "SaveLoRANode": "Save LoRA Weights",
    "LoraModelLoader": "Load LoRA Model",
    "LoadImageSetFromFolderNode": "Load Image Dataset from Folder",
    "LossGraphNode": "Plot Loss Graph",
}

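ComfyUI discovers the nodes in this file through those two module-level dictionaries: the keys are internal node IDs and the values are the node classes and their display names. Any additional node would follow the same shape; a hypothetical minimal example (ExampleNode is for illustration only, not part of the commit):

    class ExampleNode:
        @classmethod
        def INPUT_TYPES(s):
            return {"required": {"value": ("FLOAT", {"default": 0.0})}}

        RETURN_TYPES = ("FLOAT",)
        FUNCTION = "run"
        CATEGORY = "training"

        def run(self, value):
            return (value,)

    NODE_CLASS_MAPPINGS["ExampleNode"] = ExampleNode
    NODE_DISPLAY_NAME_MAPPINGS["ExampleNode"] = "Example Node"
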
@ -23,6 +23,10 @@ class WebcamCapture(nodes.LoadImage):
     def load_capture(self, image, **kwargs):
         return super().load_image(folder_paths.get_annotated_filepath(image))
 
+    @classmethod
+    def IS_CHANGED(cls, image, width, height, capture_on_queue):
+        return super().IS_CHANGED(image)
+
 
 NODE_CLASS_MAPPINGS = {
     "WebcamCapture": WebcamCapture,
@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.39"
+__version__ = "0.3.41"
28
execution.py
@ -1,23 +1,35 @@
-import sys
 import copy
-import logging
-import threading
 import heapq
+import inspect
+import logging
+import sys
+import threading
 import time
 import traceback
 from enum import Enum
-import inspect
 from typing import List, Literal, NamedTuple, Optional
+
 import torch
-import nodes
+
 import comfy.model_management
-from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
-from comfy_execution.graph_utils import is_link, GraphBuilder
-from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
+import nodes
+from comfy_execution.caching import (
+    CacheKeySetID,
+    CacheKeySetInputSignature,
+    DependencyAwareCache,
+    HierarchicalCache,
+    LRUCache,
+)
+from comfy_execution.graph import (
+    DynamicPrompt,
+    ExecutionBlocker,
+    ExecutionList,
+    get_input_info,
+)
+from comfy_execution.graph_utils import GraphBuilder, is_link
 from comfy_execution.validation import validate_node_input
 
 
 class ExecutionResult(Enum):
     SUCCESS = 0
     FAILURE = 1
@ -276,6 +276,9 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
 
 
 def get_full_path(folder_name: str, filename: str) -> str | None:
+    """
+    Get the full path of a file in a folder, has to be a file
+    """
     global folder_names_and_paths
     folder_name = map_legacy(folder_name)
     if folder_name not in folder_names_and_paths:
@ -293,6 +296,9 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
 
 
 def get_full_path_or_raise(folder_name: str, filename: str) -> str:
+    """
+    Get the full path of a file in a folder, has to be a file
+    """
     full_path = get_full_path(folder_name, filename)
     if full_path is None:
         raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
@ -394,3 +400,26 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
     os.makedirs(full_output_folder, exist_ok=True)
     counter = 1
     return full_output_folder, filename, counter, subfolder, filename_prefix
+
+
+def get_input_subfolders() -> list[str]:
+    """Returns a list of all subfolder paths in the input directory, recursively.
+
+    Returns:
+        List of folder paths relative to the input directory, excluding the root directory
+    """
+    input_dir = get_input_directory()
+    folders = []
+
+    try:
+        if not os.path.exists(input_dir):
+            return []
+
+        for root, dirs, _ in os.walk(input_dir):
+            rel_path = os.path.relpath(root, input_dir)
+            if rel_path != ".":  # Only include non-root directories
+                # Normalize path separators to forward slashes
+                folders.append(rel_path.replace(os.sep, '/'))
+
+        return sorted(folders)
+    except FileNotFoundError:
+        return []
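This helper is covered by the unit tests added later in this diff; as a quick usage sketch (the directory path is a placeholder):

    import folder_paths

    folder_paths.set_input_directory("/tmp/comfy_inputs")  # hypothetical location
    print(folder_paths.get_input_subfolders())
    # e.g. ['folder1', 'folder1/subfolder1', 'folder2']
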
14
main.py
@ -17,7 +17,6 @@ if __name__ == "__main__":
     os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
     os.environ['DO_NOT_TRACK'] = '1'
-
 
 setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
 
 def apply_custom_paths():
@ -238,6 +237,15 @@ def cleanup_temp():
     shutil.rmtree(temp_dir, ignore_errors=True)
 
 
+def setup_database():
+    try:
+        from app.database.db import init_db, dependencies_available
+        if dependencies_available():
+            init_db()
+    except Exception as e:
+        logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
+
+
 def start_comfyui(asyncio_loop=None):
     """
     Starts the ComfyUI server using the provided asyncio event loop or creates a new one.
@ -266,6 +274,7 @@ def start_comfyui(asyncio_loop=None):
     hook_breaker_ac10a0.restore_functions()
 
     cuda_malloc_warning()
+    setup_database()
 
     prompt_server.add_routes()
     hijack_progress(prompt_server)
@ -300,6 +309,9 @@ if __name__ == "__main__":
     logging.info("Python version: {}".format(sys.version))
     logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
 
+    if sys.version_info.major == 3 and sys.version_info.minor < 10:
+        logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")
+
     event_loop, _, start_all_func = start_comfyui()
     try:
         x = start_all_func()
21
nodes.py
@ -2067,6 +2067,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "ImageQuantize": "Image Quantize",
     "ImageSharpen": "Image Sharpen",
     "ImageScaleToTotalPixels": "Scale Image to Total Pixels",
+    "GetImageSize": "Get Image Size",
     # _for_testing
     "VAEDecodeTiled": "VAE Decode (Tiled)",
     "VAEEncodeTiled": "VAE Encode (Tiled)",
@ -2124,6 +2125,25 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes
 
     LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir)
 
+    try:
+        from comfy_config import config_parser
+
+        project_config = config_parser.extract_node_configuration(module_path)
+
+        web_dir_name = project_config.tool_comfy.web
+
+        if web_dir_name:
+            web_dir_path = os.path.join(module_path, web_dir_name)
+
+            if os.path.isdir(web_dir_path):
+                project_name = project_config.project.name
+
+                EXTENSION_WEB_DIRS[project_name] = web_dir_path
+
+                logging.info("Automatically register web folder {} for {}".format(web_dir_name, project_name))
+    except Exception as e:
+        logging.warning(f"Unable to parse pyproject.toml due to lack of the dependency pydantic-settings, please run 'pip install -r requirements.txt': {e}")
+
     if hasattr(module, "WEB_DIRECTORY") and getattr(module, "WEB_DIRECTORY") is not None:
         web_dir = os.path.abspath(os.path.join(module_dir, getattr(module, "WEB_DIRECTORY")))
         if os.path.isdir(web_dir):
@ -2211,6 +2231,7 @@ def init_builtin_extra_nodes():
     "nodes_model_downscale.py",
     "nodes_images.py",
     "nodes_video_model.py",
+    "nodes_train.py",
     "nodes_sag.py",
     "nodes_perpneg.py",
     "nodes_stable3d.py",
@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.39"
+version = "0.3.41"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
@ -1,6 +1,6 @@
-comfyui-frontend-package==1.21.3
-comfyui-workflow-templates==0.1.23
-comfyui-embedded-docs==0.2.0
+comfyui-frontend-package==1.21.7
+comfyui-workflow-templates==0.1.28
+comfyui-embedded-docs==0.2.2
 torch
 torchsde
 torchvision
@ -18,6 +18,8 @@ Pillow
 scipy
 tqdm
 psutil
+alembic
+SQLAlchemy
 
 #non essential dependencies:
 kornia>=0.7.1
@ -25,3 +27,4 @@ spandrel
 soundfile
 av>=14.2.0
 pydantic~=2.0
+pydantic-settings~=2.0
@ -390,7 +390,7 @@ class PromptServer():
         async def view_image(request):
             if "filename" in request.rel_url.query:
                 filename = request.rel_url.query["filename"]
-                filename,output_dir = folder_paths.annotated_filepath(filename)
+                filename, output_dir = folder_paths.annotated_filepath(filename)
 
                 if not filename:
                     return web.Response(status=400)
@ -476,9 +476,8 @@ class PromptServer():
                 # Get content type from mimetype, defaulting to 'application/octet-stream'
                 content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
 
-                # For security, force certain extensions to download instead of display
-                file_extension = os.path.splitext(filename)[1].lower()
-                if file_extension in {'.html', '.htm', '.js', '.css'}:
+                # For security, force certain mimetypes to download instead of display
+                if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
                     content_type = 'application/octet-stream'  # Forces download
 
                 return web.FileResponse(
@ -789,7 +788,7 @@ class PromptServer():
                 if hasattr(Image, 'Resampling'):
                     resampling = Image.Resampling.BILINEAR
                 else:
-                    resampling = Image.ANTIALIAS
+                    resampling = Image.Resampling.LANCZOS
 
                 image = ImageOps.contain(image, (max_size, max_size), resampling)
                 type_num = 1
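A note on the content-type hardening a few hunks up: matching on the guessed mimetype instead of the raw extension also catches aliases such as `.htm` and `.xhtml`. A small sketch of what `mimetypes.guess_type` reports (values from the Python standard library; for `.js` it returns `text/javascript` on Python 3.12+ and `application/javascript` on older versions):

    import mimetypes

    for name in ("page.html", "page.htm", "page.xhtml", "style.css", "image.png"):
        print(name, mimetypes.guess_type(name)[0])
    # page.html text/html
    # page.htm text/html
    # page.xhtml application/xhtml+xml
    # style.css text/css
    # image.png image/png
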
@ -5,7 +5,10 @@ from unittest.mock import patch, MagicMock
 mock_nodes = MagicMock()
 mock_nodes.MAX_RESOLUTION = 16384
 
-with patch.dict('sys.modules', {'nodes': mock_nodes}):
+# Mock server module for PromptServer
+mock_server = MagicMock()
+
+with patch.dict('sys.modules', {'nodes': mock_nodes, 'server': mock_server}):
     from comfy_extras.nodes_images import ImageStitch
51
tests-unit/folder_paths_test/misc_test.py
Normal file
@ -0,0 +1,51 @@
import pytest
import os
import tempfile
from folder_paths import get_input_subfolders, set_input_directory


@pytest.fixture(scope="module")
def mock_folder_structure():
    with tempfile.TemporaryDirectory() as temp_dir:
        # Create a nested folder structure
        folders = [
            "folder1",
            "folder1/subfolder1",
            "folder1/subfolder2",
            "folder2",
            "folder2/deep",
            "folder2/deep/nested",
            "empty_folder"
        ]

        # Create the folders
        for folder in folders:
            os.makedirs(os.path.join(temp_dir, folder))

        # Add some files to test they're not included
        with open(os.path.join(temp_dir, "root_file.txt"), "w") as f:
            f.write("test")
        with open(os.path.join(temp_dir, "folder1", "test.txt"), "w") as f:
            f.write("test")

        set_input_directory(temp_dir)
        yield temp_dir


def test_gets_all_folders(mock_folder_structure):
    folders = get_input_subfolders()
    expected = ["folder1", "folder1/subfolder1", "folder1/subfolder2",
                "folder2", "folder2/deep", "folder2/deep/nested", "empty_folder"]
    assert sorted(folders) == sorted(expected)


def test_handles_nonexistent_input_directory():
    with tempfile.TemporaryDirectory() as temp_dir:
        nonexistent = os.path.join(temp_dir, "nonexistent")
        set_input_directory(nonexistent)
        assert get_input_subfolders() == []


def test_empty_input_directory():
    with tempfile.TemporaryDirectory() as temp_dir:
        set_input_directory(temp_dir)
        assert get_input_subfolders() == []  # Empty since we don't include root
18
utils/install_util.py
Normal file
@ -0,0 +1,18 @@
from pathlib import Path
import sys

# The path to the requirements.txt file
requirements_path = Path(__file__).parents[1] / "requirements.txt"


def get_missing_requirements_message():
    """The warning message to display when a package is missing."""

    extra = ""
    if sys.flags.no_user_site:
        extra = "-s "
    return f"""
Please install the updated requirements.txt file by running:
{sys.executable} {extra}-m pip install -r {requirements_path}

If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem.
""".strip()
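A sketch of how this helper is meant to be used by callers (assuming the module is importable as `utils.install_util`; the interpreter and repository paths in the output depend on the environment):

    from utils.install_util import get_missing_requirements_message

    print(get_missing_requirements_message())
    # Please install the updated requirements.txt file by running:
    # /usr/bin/python3 -m pip install -r /path/to/ComfyUI/requirements.txt
    # ...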