Merge branch 'master' into worksplit-multigpu

Jedrzej Kosinski 2025-05-12 19:29:13 -05:00
commit 9726eac475
121 changed files with 37122 additions and 273 deletions

View File

@ -63,7 +63,12 @@ except:
print("checking out master branch") # noqa: T201
branch = repo.lookup_branch('master')
if branch is None:
ref = repo.lookup_reference('refs/remotes/origin/master')
try:
ref = repo.lookup_reference('refs/remotes/origin/master')
except:
print("pulling.") # noqa: T201
pull(repo)
ref = repo.lookup_reference('refs/remotes/origin/master')
repo.checkout(ref)
branch = repo.lookup_branch('master')
if branch is None:

View File

@ -12,7 +12,7 @@ on:
description: 'CUDA version'
required: true
type: string
default: "126"
default: "128"
python_minor:
description: 'Python minor version'
required: true
@ -22,7 +22,7 @@ on:
description: 'Python patch version'
required: true
type: string
default: "9"
default: "10"
jobs:
@ -36,7 +36,7 @@ jobs:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.git_tag }}
fetch-depth: 0
fetch-depth: 150
persist-credentials: false
- uses: actions/cache/restore@v4
id: cache
@ -70,7 +70,7 @@ jobs:
cd ..
git clone --depth 1 https://github.com/comfyanonymous/taesd
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
mkdir ComfyUI_windows_portable
mv python_embeded ComfyUI_windows_portable
@ -85,12 +85,14 @@ jobs:
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
cd ComfyUI_windows_portable
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
python_embeded/python.exe -s ./update/update.py ComfyUI/
ls
- name: Upload binaries to release

View File

@ -17,7 +17,7 @@ jobs:
path: "ComfyUI"
- uses: actions/setup-python@v4
with:
python-version: '3.9'
python-version: '3.10'
- name: Install requirements
run: |
python -m pip install --upgrade pip

.github/workflows/update-api-stubs.yml (new file, 56 lines)
View File

@ -0,0 +1,56 @@
name: Generate Pydantic Stubs from api.comfy.org
on:
schedule:
- cron: '0 0 * * 1'
workflow_dispatch:
jobs:
generate-models:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install 'datamodel-code-generator[http]'
npm install @redocly/cli
- name: Download OpenAPI spec
run: |
curl -o openapi.yaml https://api.comfy.org/openapi
- name: Filter OpenAPI spec with Redocly
run: |
npx @redocly/cli bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components
- name: Generate API models
run: |
datamodel-codegen --use-subclass-enum --input filtered-openapi.yaml --output comfy_api_nodes/apis --output-model-type pydantic_v2.BaseModel
- name: Check for changes
id: git-check
run: |
git diff --exit-code comfy_api_nodes/apis || echo "changes=true" >> $GITHUB_OUTPUT
- name: Create Pull Request
if: steps.git-check.outputs.changes == 'true'
uses: peter-evans/create-pull-request@v5
with:
commit-message: 'chore: update API models from OpenAPI spec'
title: 'Update API models from api.comfy.org'
body: |
This PR updates the API models based on the latest api.comfy.org OpenAPI specification.
Generated automatically by a GitHub workflow.
branch: update-api-stubs
delete-branch: true
base: master
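For reference, the generation steps in this workflow can be reproduced locally. The sketch below is an approximation of what the workflow runs, assuming `datamodel-code-generator` and `@redocly/cli` are installed and that `comfy_api_nodes/redocly.yaml` exists as referenced above; it is not a supported entry point.

```python
# Hedged local reproduction of the workflow steps above.
import subprocess
import urllib.request

# 1. Download the OpenAPI spec.
urllib.request.urlretrieve("https://api.comfy.org/openapi", "openapi.yaml")

# 2. Filter it with Redocly using the repo's config.
subprocess.run([
    "npx", "@redocly/cli", "bundle", "openapi.yaml",
    "--output", "filtered-openapi.yaml",
    "--config", "comfy_api_nodes/redocly.yaml",
    "--remove-unused-components",
], check=True)

# 3. Generate Pydantic v2 models into comfy_api_nodes/apis.
subprocess.run([
    "datamodel-codegen", "--use-subclass-enum",
    "--input", "filtered-openapi.yaml",
    "--output", "comfy_api_nodes/apis",
    "--output-model-type", "pydantic_v2.BaseModel",
], check=True)
```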

View File

@ -17,7 +17,7 @@ on:
description: 'cuda version'
required: true
type: string
default: "126"
default: "128"
python_minor:
description: 'python minor version'
@ -29,7 +29,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "9"
default: "10"
# push:
# branches:
# - master

View File

@ -56,7 +56,7 @@ jobs:
cd ..
git clone --depth 1 https://github.com/comfyanonymous/taesd
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
mkdir ComfyUI_windows_portable_nightly_pytorch
mv python_embeded ComfyUI_windows_portable_nightly_pytorch

View File

@ -7,7 +7,7 @@ on:
description: 'cuda version'
required: true
type: string
default: "126"
default: "128"
python_minor:
description: 'python minor version'
@ -19,7 +19,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "9"
default: "10"
# push:
# branches:
# - master
@ -50,7 +50,7 @@ jobs:
- uses: actions/checkout@v4
with:
fetch-depth: 0
fetch-depth: 150
persist-credentials: false
- shell: bash
run: |
@ -67,7 +67,7 @@ jobs:
cd ..
git clone --depth 1 https://github.com/comfyanonymous/taesd
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
mkdir ComfyUI_windows_portable
mv python_embeded ComfyUI_windows_portable
@ -82,12 +82,14 @@ jobs:
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z
cd ComfyUI_windows_portable
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
python_embeded/python.exe -s ./update/update.py ComfyUI/
ls
- name: Upload binaries to release

.gitignore (3 lines added)
View File

@ -21,3 +21,6 @@ venv/
*.log
web_custom_versions/
.DS_Store
openapi.yaml
filtered-openapi.yaml
uv.lock

View File

@ -49,7 +49,6 @@ Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon,
## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Image Models
@ -70,9 +69,11 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- Audio Models
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- 3D Models
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that change between executions.
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
@ -99,6 +100,23 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
## Release Process
ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories:
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
- Releases a new stable version (e.g., v0.7.0)
- Serves as the foundation for the desktop release
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
- Builds a new release using the latest stable core version
- Version numbers match the core release (e.g., Desktop v1.7.0 uses Core v1.7.0)
3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**
- Weekly frontend updates are merged into the core repository
- Features are frozen for the upcoming core release
- Development continues for the next release cycle
## Shortcuts
| Keybind | Explanation |
@ -149,8 +167,6 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you
If you have trouble extracting it, right click the file -> properties -> unblock
If you have a 50 series Blackwell card like a 5090 or 5080 see [this discussion thread](https://github.com/comfyanonymous/ComfyUI/discussions/6643)
#### How do I share models between another UI and ComfyUI?
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
@ -216,9 +232,9 @@ Additional discussion and help can be found [here](https://github.com/comfyanony
Nvidia users should install stable pytorch using this command:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126```
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128```
This is the command to install pytorch nightly instead which supports the new blackwell 50xx series GPUs and might have performance improvements.
This is the command to install pytorch nightly instead which might have performance improvements.
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```

View File

@ -93,16 +93,20 @@ class CustomNodeManager:
def add_routes(self, routes, webapp, loadedModules):
example_workflow_folder_names = ["example_workflows", "example", "examples", "workflow", "workflows"]
@routes.get("/workflow_templates")
async def get_workflow_templates(request):
"""Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted."""
files = [
file
for folder in folder_paths.get_folder_paths("custom_nodes")
for file in glob.glob(
os.path.join(folder, "*/example_workflows/*.json")
)
]
files = []
for folder in folder_paths.get_folder_paths("custom_nodes"):
for folder_name in example_workflow_folder_names:
pattern = os.path.join(folder, f"*/{folder_name}/*.json")
matched_files = glob.glob(pattern)
files.extend(matched_files)
workflow_templates_dict = (
{}
) # custom_nodes folder name -> example workflow names
@ -118,15 +122,22 @@ class CustomNodeManager:
# Serve workflow templates from custom nodes.
for module_name, module_dir in loadedModules:
workflows_dir = os.path.join(module_dir, "example_workflows")
if os.path.exists(workflows_dir):
webapp.add_routes(
[
web.static(
"/api/workflow_templates/" + module_name, workflows_dir
)
]
)
for folder_name in example_workflow_folder_names:
workflows_dir = os.path.join(module_dir, folder_name)
if os.path.exists(workflows_dir):
if folder_name != "example_workflows":
logging.debug(
"Found example workflow folder '%s' for custom node '%s', consider renaming it to 'example_workflows'",
folder_name, module_name)
webapp.add_routes(
[
web.static(
"/api/workflow_templates/" + module_name, workflows_dir
)
]
)
@routes.get("/i18n")
async def get_i18n(request):
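As context for the change above, here is a hedged sketch of which folders the template scan now covers; the custom-node root and file names below are made up for illustration only.

```python
# Hedged illustration of the folder scan above; "custom_nodes" stands in for one of the
# paths returned by folder_paths.get_folder_paths("custom_nodes").
import glob
import os

example_workflow_folder_names = ["example_workflows", "example", "examples", "workflow", "workflows"]
custom_nodes_root = "custom_nodes"

files = []
for folder_name in example_workflow_folder_names:
    # e.g. custom_nodes/my_node_pack/examples/basic.json (hypothetical layout)
    files.extend(glob.glob(os.path.join(custom_nodes_root, f"*/{folder_name}/*.json")))

# Matching folders are also served statically at /api/workflow_templates/<module_name>,
# with a debug hint suggesting the canonical name "example_workflows".
print(files)
```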

View File

@ -197,6 +197,112 @@ class UserManager():
return web.json_response(results)
@routes.get("/v2/userdata")
async def list_userdata_v2(request):
"""
List files and directories in a user's data directory.
This endpoint provides a structured listing of contents within a specified
subdirectory of the user's data storage.
Query Parameters:
- path (optional): The relative path within the user's data directory
to list. Defaults to the root ('').
Returns:
- 400: If the requested path is invalid, outside the user's data directory, or is not a directory.
- 404: If the requested path does not exist.
- 403: If the user is invalid.
- 500: If there is an error reading the directory contents.
- 200: JSON response containing a list of file and directory objects.
Each object includes:
- name: The name of the file or directory.
- type: 'file' or 'directory'.
- path: The relative path from the user's data root.
- size (for files): The size in bytes.
- modified (for files): The last modified timestamp (Unix epoch).
"""
requested_rel_path = request.rel_url.query.get('path', '')
# URL-decode the path parameter
try:
requested_rel_path = parse.unquote(requested_rel_path)
except Exception as e:
logging.warning(f"Failed to decode path parameter: {requested_rel_path}, Error: {e}")
return web.Response(status=400, text="Invalid characters in path parameter")
# Check user validity and get the absolute path for the requested directory
try:
base_user_path = self.get_request_user_filepath(request, None, create_dir=False)
if requested_rel_path:
target_abs_path = self.get_request_user_filepath(request, requested_rel_path, create_dir=False)
else:
target_abs_path = base_user_path
except KeyError as e:
# Invalid user detected by get_request_user_id inside get_request_user_filepath
logging.warning(f"Access denied for user: {e}")
return web.Response(status=403, text="Invalid user specified in request")
if not target_abs_path:
# Path traversal or other issue detected by get_request_user_filepath
return web.Response(status=400, text="Invalid path requested")
# Handle cases where the user directory or target path doesn't exist
if not os.path.exists(target_abs_path):
# Check if it's the base user directory that's missing (new user case)
if target_abs_path == base_user_path:
# It's okay if the base user directory doesn't exist yet, return empty list
return web.json_response([])
else:
# A specific subdirectory was requested but doesn't exist
return web.Response(status=404, text="Requested path not found")
if not os.path.isdir(target_abs_path):
return web.Response(status=400, text="Requested path is not a directory")
results = []
try:
for root, dirs, files in os.walk(target_abs_path, topdown=True):
# Process directories
for dir_name in dirs:
dir_path = os.path.join(root, dir_name)
rel_path = os.path.relpath(dir_path, base_user_path).replace(os.sep, '/')
results.append({
"name": dir_name,
"path": rel_path,
"type": "directory"
})
# Process files
for file_name in files:
file_path = os.path.join(root, file_name)
rel_path = os.path.relpath(file_path, base_user_path).replace(os.sep, '/')
entry_info = {
"name": file_name,
"path": rel_path,
"type": "file"
}
try:
stats = os.stat(file_path) # Use os.stat for potentially better performance with os.walk
entry_info["size"] = stats.st_size
entry_info["modified"] = stats.st_mtime
except OSError as stat_error:
logging.warning(f"Could not stat file {file_path}: {stat_error}")
pass # Include file with available info
results.append(entry_info)
except OSError as e:
logging.error(f"Error listing directory {target_abs_path}: {e}")
return web.Response(status=500, text="Error reading directory contents")
# Sort results alphabetically, directories first then files
results.sort(key=lambda x: (x['type'] != 'directory', x['name'].lower()))
return web.json_response(results)
def get_user_data_path(request, check_exists = False, param = "file"):
file = request.match_info.get(param, None)
if not file:
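Based on the docstring of `list_userdata_v2` above, a client-side sketch of calling the new endpoint follows; the server address and the example path are assumptions for illustration, not part of the change.

```python
# Hedged client sketch for GET /v2/userdata (localhost address and "workflows" path are assumptions).
import json
import urllib.request
from urllib.parse import quote

path = "workflows"  # hypothetical subdirectory of the user's data directory
with urllib.request.urlopen(f"http://127.0.0.1:8188/v2/userdata?path={quote(path)}") as resp:
    entries = json.load(resp)

# Each entry has: name, path, type ('file' or 'directory'); files also carry size and modified.
for entry in entries:
    print(entry["type"], entry["path"], entry.get("size"))
```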

View File

@ -128,6 +128,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
@ -141,12 +142,15 @@ class PerformanceFeature(enum.Enum):
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
@ -191,6 +195,13 @@ parser.add_argument("--user-directory", type=is_valid_directory, default=None, h
parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
parser.add_argument(
"--comfy-api-base",
type=str,
default="https://api.comfy.org",
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
)
if comfy.options.args_parsing:
args = parser.parse_args()
else:

View File

@ -18,6 +18,7 @@ class Output:
setattr(self, key, item)
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
image = image[:, :, :, :3] if image.shape[3] > 3 else image
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1)

View File

@ -1,7 +1,7 @@
"""Comfy-specific type hinting"""
from __future__ import annotations
from typing import Literal, TypedDict
from typing import Literal, TypedDict, Optional
from typing_extensions import NotRequired
from abc import ABC, abstractmethod
from enum import Enum
@ -48,6 +48,7 @@ class IO(StrEnum):
FACE_ANALYSIS = "FACE_ANALYSIS"
BBOX = "BBOX"
SEGS = "SEGS"
VIDEO = "VIDEO"
ANY = "*"
"""Always matches any type, but at a price.
@ -120,6 +121,10 @@ class InputTypeOptions(TypedDict):
Available from frontend v1.17.5
Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3548
"""
widgetType: NotRequired[str]
"""Specifies a type to be used for widget initialization if different from the input type.
Available from frontend v1.18.0
https://github.com/Comfy-Org/ComfyUI_frontend/pull/3550"""
# class InputTypeNumber(InputTypeOptions):
# default: float | int
min: NotRequired[float]
@ -229,6 +234,8 @@ class ComfyNodeABC(ABC):
"""Flags a node as experimental, informing users that it may change or not work as expected."""
DEPRECATED: bool
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
API_NODE: Optional[bool]
"""Flags a node as an API node."""
@classmethod
@abstractmethod
@ -267,7 +274,7 @@ class ComfyNodeABC(ABC):
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
"""
OUTPUT_IS_LIST: tuple[bool]
OUTPUT_IS_LIST: tuple[bool, ...]
"""A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.
@ -286,7 +293,7 @@ class ComfyNodeABC(ABC):
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
"""
RETURN_TYPES: tuple[IO]
RETURN_TYPES: tuple[IO, ...]
"""A tuple representing the outputs of this node.
Usage::
@ -295,12 +302,12 @@ class ComfyNodeABC(ABC):
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-types
"""
RETURN_NAMES: tuple[str]
RETURN_NAMES: tuple[str, ...]
"""The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-names
"""
OUTPUT_TOOLTIPS: tuple[str]
OUTPUT_TOOLTIPS: tuple[str, ...]
"""A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
FUNCTION: str
"""The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`

View File

@ -116,7 +116,7 @@ class Dino2Embeddings(torch.nn.Module):
def forward(self, pixel_values):
x = self.patch_embeddings(pixel_values)
# TODO: mask_token?
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
return x

View File

@ -1277,6 +1277,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
phi1_fn = lambda t: torch.expm1(t) / t
phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
old_sigma_down = None
old_denoised = None
uncond_denoised = None
def post_cfg_function(args):
@ -1304,9 +1305,9 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
x = x + d * dt
else:
# Second order multistep method in https://arxiv.org/pdf/2308.02157
t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigma_down), t_fn(sigmas[i - 1])
t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1])
h = t_next - t
c2 = (t_prev - t) / h
c2 = (t_prev - t_old) / h
phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
@ -1326,6 +1327,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
old_denoised = uncond_denoised
else:
old_denoised = denoised
old_sigma_down = sigma_down
return x
@torch.no_grad()
@ -1345,28 +1347,52 @@ def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, cal
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
@torch.no_grad()
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
"""Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
old_d = None
uncond_denoised = None
def post_cfg_function(args):
nonlocal uncond_denoised
uncond_denoised = args["uncond_denoised"]
return args["denoised"]
if cfg_pp:
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
d = to_d(x, sigmas[i], denoised)
if cfg_pp:
d = to_d(x, sigmas[i], uncond_denoised)
else:
d = to_d(x, sigmas[i], denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
dt = sigmas[i + 1] - sigmas[i]
if i == 0:
# Euler method
x = x + d * dt
if cfg_pp:
x = denoised + d * sigmas[i + 1]
else:
x = x + d * dt
else:
# Gradient estimation
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
x = x + d_bar * dt
if cfg_pp:
d_bar = (ge_gamma - 1) * (d - old_d)
x = denoised + d * sigmas[i + 1] + d_bar * dt
else:
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
x = x + d_bar * dt
old_d = d
return x
@torch.no_grad()
def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
@torch.no_grad()
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
"""

View File

@ -466,3 +466,7 @@ class Hunyuan3Dv2mini(LatentFormat):
latent_channels = 64
latent_dimensions = 1
scale_factor = 1.0188137142395404
class ACEAudio(LatentFormat):
latent_channels = 8
latent_dimensions = 2

comfy/ldm/ace/attention.py (new file, 761 lines)
View File

@ -0,0 +1,761 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/attention.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union, Optional
import torch
import torch.nn.functional as F
from torch import nn
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention
class Attention(nn.Module):
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
kv_heads: Optional[int] = None,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
qk_norm: Optional[str] = None,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
scale_qk: bool = True,
only_cross_attention: bool = False,
eps: float = 1e-5,
rescale_output_factor: float = 1.0,
residual_connection: bool = False,
processor=None,
out_dim: int = None,
out_context_dim: int = None,
context_pre_only=None,
pre_only=False,
elementwise_affine: bool = True,
is_causal: bool = False,
dtype=None, device=None, operations=None
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
self.query_dim = query_dim
self.use_bias = bias
self.is_cross_attention = cross_attention_dim is not None
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.is_causal = is_causal
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.heads = out_dim // dim_head if out_dim is not None else heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self.sliceable_head_dim = heads
self.added_kv_proj_dim = added_kv_proj_dim
self.only_cross_attention = only_cross_attention
if self.added_kv_proj_dim is None and self.only_cross_attention:
raise ValueError(
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
)
self.group_norm = None
self.spatial_norm = None
self.norm_q = None
self.norm_k = None
self.norm_cross = None
self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
if not self.only_cross_attention:
# only relevant for the `AddedKVProcessor` classes
self.to_k = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
self.to_v = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
else:
self.to_k = None
self.to_v = None
self.added_proj_bias = added_proj_bias
if self.added_kv_proj_dim is not None:
self.add_k_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
self.add_v_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
if self.context_pre_only is not None:
self.add_q_proj = operations.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias, dtype=dtype, device=device)
else:
self.add_q_proj = None
self.add_k_proj = None
self.add_v_proj = None
if not self.pre_only:
self.to_out = nn.ModuleList([])
self.to_out.append(operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device))
self.to_out.append(nn.Dropout(dropout))
else:
self.to_out = None
if self.context_pre_only is not None and not self.context_pre_only:
self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
else:
self.to_add_out = None
self.norm_added_q = None
self.norm_added_k = None
self.processor = processor
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**cross_attention_kwargs,
) -> torch.Tensor:
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
class CustomLiteLAProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections. add rms norm for query and key and apply RoPE"""
def __init__(self):
self.kernel_func = nn.ReLU(inplace=False)
self.eps = 1e-15
self.pad_val = 1.0
def apply_rotary_emb(
self,
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
x (`torch.Tensor`):
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
hidden_states_len = hidden_states.shape[1]
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
if encoder_hidden_states is not None:
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size = hidden_states.shape[0]
# `sample` projections.
dtype = hidden_states.dtype
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# `context` projections.
has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
# attention
if not attn.is_cross_attention:
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
else:
query = hidden_states
key = encoder_hidden_states
value = encoder_hidden_states
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
# RoPE expects [B, H, S, D] input
# query is currently [B, H, D, S]; transpose it to [B, H, S, D] before applying RoPE
query = query.permute(0, 1, 3, 2) # [B, H, S, D] (from [B, H, D, S])
# Apply query and key normalization if needed
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if rotary_freqs_cis is not None:
query = self.apply_rotary_emb(query, rotary_freqs_cis)
if not attn.is_cross_attention:
key = self.apply_rotary_emb(key, rotary_freqs_cis)
elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
# query is now [B, H, S, D]; convert it back to [B, H, D, S]
query = query.permute(0, 1, 3, 2) # [B, H, D, S]
if attention_mask is not None:
# attention_mask: [B, S] -> [B, 1, S, 1]
attention_mask = attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S, 1]
query = query * attention_mask.permute(0, 1, 3, 2) # [B, H, S, D] * [B, 1, S, 1]
if not attn.is_cross_attention:
key = key * attention_mask # multiply key [B, h, S, D] by mask [B, 1, S, 1]
value = value * attention_mask.permute(0, 1, 3, 2) # value is [B, h, D, S], so permute the mask to match the S dimension
if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S_enc, 1]
# key is now [B, h, S_enc, D], value is [B, h, D, S_enc]
key = key * encoder_attention_mask # [B, h, S_enc, D] * [B, 1, S_enc, 1]
value = value * encoder_attention_mask.permute(0, 1, 3, 2) # [B, h, D, S_enc] * [B, 1, 1, S_enc]
query = self.kernel_func(query)
key = self.kernel_func(key)
query, key, value = query.float(), key.float(), value.float()
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
vk = torch.matmul(value, key)
hidden_states = torch.matmul(vk, query)
if hidden_states.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.float()
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
hidden_states = hidden_states.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)
hidden_states = hidden_states.to(dtype)
if encoder_hidden_states is not None:
encoder_hidden_states = encoder_hidden_states.to(dtype)
# Split the attention outputs.
if encoder_hidden_states is not None and not attn.is_cross_attention and has_encoder_hidden_state_proj:
hidden_states, encoder_hidden_states = (
hidden_states[:, : hidden_states_len],
hidden_states[:, hidden_states_len:],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if encoder_hidden_states is not None and not attn.context_pre_only and not attn.is_cross_attention and hasattr(attn, "to_add_out"):
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if encoder_hidden_states is not None and context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if torch.get_autocast_gpu_dtype() == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
if encoder_hidden_states is not None:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return hidden_states, encoder_hidden_states
class CustomerAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def apply_rotary_emb(
self,
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
x (`torch.Tensor`):
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
*args,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if rotary_freqs_cis is not None:
query = self.apply_rotary_emb(query, rotary_freqs_cis)
if not attn.is_cross_attention:
key = self.apply_rotary_emb(key, rotary_freqs_cis)
elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
# attention_mask: N x S1
# encoder_attention_mask: N x S2
# cross attention: combine attention_mask and encoder_attention_mask
combined_mask = attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
attention_mask = attention_mask[:, None, :, :].expand(-1, attn.heads, -1, -1).to(query.dtype)
elif not attn.is_cross_attention and attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = optimized_attention(
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
).to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def val2list(x: list or tuple or any, repeat_time=1) -> list: # type: ignore
"""Repeat `val` for `repeat_time` times and return the list or val if list/tuple."""
if isinstance(x, (list, tuple)):
return list(x)
return [x for _ in range(repeat_time)]
def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple: # type: ignore
"""Return tuple with min_len by repeating element at idx_repeat."""
# convert to list first
x = val2list(x)
# repeat elements if necessary
if len(x) > 0:
x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
return tuple(x)
def t2i_modulate(x, shift, scale):
return x * (1 + scale) + shift
def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
if isinstance(kernel_size, tuple):
return tuple([get_same_padding(ks) for ks in kernel_size])
else:
assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
return kernel_size // 2
class ConvLayer(nn.Module):
def __init__(
self,
in_dim: int,
out_dim: int,
kernel_size=3,
stride=1,
dilation=1,
groups=1,
padding: Union[int, None] = None,
use_bias=False,
norm=None,
act=None,
dtype=None, device=None, operations=None
):
super().__init__()
if padding is None:
padding = get_same_padding(kernel_size)
padding *= dilation
self.in_dim = in_dim
self.out_dim = out_dim
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.groups = groups
self.padding = padding
self.use_bias = use_bias
self.conv = operations.Conv1d(
in_dim,
out_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=use_bias,
device=device,
dtype=dtype
)
if norm is not None:
self.norm = operations.RMSNorm(out_dim, elementwise_affine=False, dtype=dtype, device=device)
else:
self.norm = None
if act is not None:
self.act = nn.SiLU(inplace=True)
else:
self.act = None
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv(x)
if self.norm:
x = self.norm(x)
if self.act:
x = self.act(x)
return x
class GLUMBConv(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int,
out_feature=None,
kernel_size=3,
stride=1,
padding: Union[int, None] = None,
use_bias=False,
norm=(None, None, None),
act=("silu", "silu", None),
dilation=1,
dtype=None, device=None, operations=None
):
out_feature = out_feature or in_features
super().__init__()
use_bias = val2tuple(use_bias, 3)
norm = val2tuple(norm, 3)
act = val2tuple(act, 3)
self.glu_act = nn.SiLU(inplace=False)
self.inverted_conv = ConvLayer(
in_features,
hidden_features * 2,
1,
use_bias=use_bias[0],
norm=norm[0],
act=act[0],
dtype=dtype,
device=device,
operations=operations,
)
self.depth_conv = ConvLayer(
hidden_features * 2,
hidden_features * 2,
kernel_size,
stride=stride,
groups=hidden_features * 2,
padding=padding,
use_bias=use_bias[1],
norm=norm[1],
act=None,
dilation=dilation,
dtype=dtype,
device=device,
operations=operations,
)
self.point_conv = ConvLayer(
hidden_features,
out_feature,
1,
use_bias=use_bias[2],
norm=norm[2],
act=act[2],
dtype=dtype,
device=device,
operations=operations,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.transpose(1, 2)
x = self.inverted_conv(x)
x = self.depth_conv(x)
x, gate = torch.chunk(x, 2, dim=1)
gate = self.glu_act(gate)
x = x * gate
x = self.point_conv(x)
x = x.transpose(1, 2)
return x
class LinearTransformerBlock(nn.Module):
"""
A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.
"""
def __init__(
self,
dim,
num_attention_heads,
attention_head_dim,
use_adaln_single=True,
cross_attention_dim=None,
added_kv_proj_dim=None,
context_pre_only=False,
mlp_ratio=4.0,
add_cross_attention=False,
add_cross_attention_dim=None,
qk_norm=None,
dtype=None, device=None, operations=None
):
super().__init__()
self.norm1 = operations.RMSNorm(dim, elementwise_affine=False, eps=1e-6)
self.attn = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
added_kv_proj_dim=added_kv_proj_dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
qk_norm=qk_norm,
processor=CustomLiteLAProcessor2_0(),
dtype=dtype,
device=device,
operations=operations,
)
self.add_cross_attention = add_cross_attention
self.context_pre_only = context_pre_only
if add_cross_attention and add_cross_attention_dim is not None:
self.cross_attn = Attention(
query_dim=dim,
cross_attention_dim=add_cross_attention_dim,
added_kv_proj_dim=add_cross_attention_dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=context_pre_only,
bias=True,
qk_norm=qk_norm,
processor=CustomerAttnProcessor2_0(),
dtype=dtype,
device=device,
operations=operations,
)
self.norm2 = operations.RMSNorm(dim, 1e-06, elementwise_affine=False)
self.ff = GLUMBConv(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
use_bias=(True, True, False),
norm=(None, None, None),
act=("silu", "silu", None),
dtype=dtype,
device=device,
operations=operations,
)
self.use_adaln_single = use_adaln_single
if use_adaln_single:
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, dtype=dtype, device=device))
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: torch.FloatTensor = None,
encoder_attention_mask: torch.FloatTensor = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
temb: torch.FloatTensor = None,
):
N = hidden_states.shape[0]
# step 1: AdaLN single
if self.use_adaln_single:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
comfy.model_management.cast_to(self.scale_shift_table[None], dtype=temb.dtype, device=temb.device) + temb.reshape(N, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
if self.use_adaln_single:
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
# step 2: attention
if not self.add_cross_attention:
attn_output, encoder_hidden_states = self.attn(
hidden_states=norm_hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
)
else:
attn_output, _ = self.attn(
hidden_states=norm_hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=None,
encoder_attention_mask=None,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=None,
)
if self.use_adaln_single:
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if self.add_cross_attention:
attn_output = self.cross_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
)
hidden_states = attn_output + hidden_states
# step 3: add norm
norm_hidden_states = self.norm2(hidden_states)
if self.use_adaln_single:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
# step 4: feed forward
ff_output = self.ff(norm_hidden_states)
if self.use_adaln_single:
ff_output = gate_mlp * ff_output
hidden_states = hidden_states + ff_output
return hidden_states
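A compact sketch of the linear-attention core that `CustomLiteLAProcessor2_0` computes (ReLU feature map, extra ones-row as normalizer) is shown below; tensor layouts follow the comments above, and the random tensors and shapes are placeholders for illustration.

```python
# Hedged sketch of the LiteLA linear-attention kernel used above.
import torch
import torch.nn.functional as F

def lite_la(q, k, v, eps=1e-15):
    # q, v: [B, H, D, S]; k: [B, H, S, D] -- same layout as CustomLiteLAProcessor2_0.
    q, k = F.relu(q), F.relu(k)                     # ReLU feature map (kernel_func)
    v = F.pad(v, (0, 0, 0, 1), value=1.0)           # extra all-ones row carries the normalizer
    out = (v @ k) @ q                               # [B, H, D+1, S]; cost is linear in S
    return out[:, :, :-1] / (out[:, :, -1:] + eps)  # divide features by the normalizer row

q = torch.randn(1, 2, 8, 16)   # batch=1, heads=2, head_dim=8, seq=16
k = torch.randn(1, 2, 16, 8)
v = torch.randn(1, 2, 8, 16)
print(lite_la(q, k, v).shape)  # torch.Size([1, 2, 8, 16])
```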

File diff suppressed because it is too large

comfy/ldm/ace/model.py (new file, 385 lines)
View File

@ -0,0 +1,385 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/ace_step_transformer.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, List, Union
import torch
from torch import nn
import comfy.model_management
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from .attention import LinearTransformerBlock, t2i_modulate
from .lyric_encoder import ConformerEncoder as LyricEncoder
def cross_norm(hidden_states, controlnet_input):
# input N x T x c
mean_hidden_states, std_hidden_states = hidden_states.mean(dim=(1,2), keepdim=True), hidden_states.std(dim=(1,2), keepdim=True)
mean_controlnet_input, std_controlnet_input = controlnet_input.mean(dim=(1,2), keepdim=True), controlnet_input.std(dim=(1,2), keepdim=True)
controlnet_input = (controlnet_input - mean_controlnet_input) * (std_hidden_states / (std_controlnet_input + 1e-12)) + mean_hidden_states
return controlnet_input
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, dtype=None, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=device).float() / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
class T2IFinalLayer(nn.Module):
"""
The final layer of Sana.
"""
def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True, dtype=dtype, device=device)
self.scale_shift_table = nn.Parameter(torch.empty(2, hidden_size, dtype=dtype, device=device))
self.out_channels = out_channels
self.patch_size = patch_size
def unpatchfy(
self,
hidden_states: torch.Tensor,
width: int,
):
# 4 unpatchify
new_height, new_width = 1, hidden_states.size(1)
hidden_states = hidden_states.reshape(
shape=(hidden_states.shape[0], new_height, new_width, self.patch_size[0], self.patch_size[1], self.out_channels)
).contiguous()
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(hidden_states.shape[0], self.out_channels, new_height * self.patch_size[0], new_width * self.patch_size[1])
).contiguous()
if width > new_width:
output = torch.nn.functional.pad(output, (0, width - new_width, 0, 0), 'constant', 0)
elif width < new_width:
output = output[:, :, :, :width]
return output
def forward(self, x, t, output_length):
shift, scale = (comfy.model_management.cast_to(self.scale_shift_table[None], device=t.device, dtype=t.dtype) + t[:, None]).chunk(2, dim=1)
x = t2i_modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
# unpatchify
output = self.unpatchfy(x, output_length)
return output
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding"""
def __init__(
self,
height=16,
width=4096,
patch_size=(16, 1),
in_channels=8,
embed_dim=1152,
bias=True,
dtype=None, device=None, operations=None
):
super().__init__()
patch_size_h, patch_size_w = patch_size
self.early_conv_layers = nn.Sequential(
operations.Conv2d(in_channels, in_channels*256, kernel_size=patch_size, stride=patch_size, padding=0, bias=bias, dtype=dtype, device=device),
operations.GroupNorm(num_groups=32, num_channels=in_channels*256, eps=1e-6, affine=True, dtype=dtype, device=device),
operations.Conv2d(in_channels*256, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias, dtype=dtype, device=device)
)
self.patch_size = patch_size
self.height, self.width = height // patch_size_h, width // patch_size_w
self.base_size = self.width
def forward(self, latent):
# early convolutions, N x C x H x W -> N x 256 * sqrt(patch_size) x H/patch_size x W/patch_size
latent = self.early_conv_layers(latent)
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
return latent
class ACEStepTransformer2DModel(nn.Module):
# _supports_gradient_checkpointing = True
def __init__(
self,
in_channels: Optional[int] = 8,
num_layers: int = 28,
inner_dim: int = 1536,
attention_head_dim: int = 64,
num_attention_heads: int = 24,
mlp_ratio: float = 4.0,
out_channels: int = 8,
max_position: int = 32768,
rope_theta: float = 1000000.0,
speaker_embedding_dim: int = 512,
text_embedding_dim: int = 768,
ssl_encoder_depths: List[int] = [9, 9],
ssl_names: List[str] = ["mert", "m-hubert"],
ssl_latent_dims: List[int] = [1024, 768],
lyric_encoder_vocab_size: int = 6681,
lyric_hidden_size: int = 1024,
patch_size: List[int] = [16, 1],
max_height: int = 16,
max_width: int = 4096,
audio_model=None,
dtype=None, device=None, operations=None
):
super().__init__()
self.dtype = dtype
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.inner_dim = inner_dim
self.out_channels = out_channels
self.max_position = max_position
self.patch_size = patch_size
self.rope_theta = rope_theta
self.rotary_emb = Qwen2RotaryEmbedding(
dim=self.attention_head_dim,
max_position_embeddings=self.max_position,
base=self.rope_theta,
dtype=dtype,
device=device,
)
# 2. Define input layers
self.in_channels = in_channels
self.num_layers = num_layers
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
LinearTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.num_attention_heads,
attention_head_dim=attention_head_dim,
mlp_ratio=mlp_ratio,
add_cross_attention=True,
add_cross_attention_dim=self.inner_dim,
dtype=dtype,
device=device,
operations=operations,
)
for i in range(self.num_layers)
]
)
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim, dtype=dtype, device=device, operations=operations)
self.t_block = nn.Sequential(nn.SiLU(), operations.Linear(self.inner_dim, 6 * self.inner_dim, bias=True, dtype=dtype, device=device))
# speaker
self.speaker_embedder = operations.Linear(speaker_embedding_dim, self.inner_dim, dtype=dtype, device=device)
# genre
self.genre_embedder = operations.Linear(text_embedding_dim, self.inner_dim, dtype=dtype, device=device)
# lyric
self.lyric_embs = operations.Embedding(lyric_encoder_vocab_size, lyric_hidden_size, dtype=dtype, device=device)
self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0, dtype=dtype, device=device, operations=operations)
self.lyric_proj = operations.Linear(lyric_hidden_size, self.inner_dim, dtype=dtype, device=device)
projector_dim = 2 * self.inner_dim
self.projectors = nn.ModuleList([
nn.Sequential(
operations.Linear(self.inner_dim, projector_dim, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(projector_dim, projector_dim, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(projector_dim, ssl_dim, dtype=dtype, device=device),
) for ssl_dim in ssl_latent_dims
])
self.proj_in = PatchEmbed(
height=max_height,
width=max_width,
patch_size=patch_size,
embed_dim=self.inner_dim,
bias=True,
dtype=dtype,
device=device,
operations=operations,
)
self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels, dtype=dtype, device=device, operations=operations)
def forward_lyric_encoder(
self,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
out_dtype=None,
):
# N x T x D
lyric_embs = self.lyric_embs(lyric_token_idx, out_dtype=out_dtype)
prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
return prompt_prenet_out
def encode(
self,
encoder_text_hidden_states: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.LongTensor] = None,
speaker_embeds: Optional[torch.FloatTensor] = None,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
lyrics_strength=1.0,
):
bs = encoder_text_hidden_states.shape[0]
device = encoder_text_hidden_states.device
# speaker embedding
encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)
# genre embedding
encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)
# lyric
encoder_lyric_hidden_states = self.forward_lyric_encoder(
lyric_token_idx=lyric_token_idx,
lyric_mask=lyric_mask,
out_dtype=encoder_text_hidden_states.dtype,
)
encoder_lyric_hidden_states *= lyrics_strength
encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1)
encoder_hidden_mask = None
if text_attention_mask is not None:
speaker_mask = torch.ones(bs, 1, device=device)
encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)
return encoder_hidden_states, encoder_hidden_mask
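    # Illustrative note on the layout built above: the conditioning sequence along dim=1 is
    # [speaker (1 token), genre/text (T_text tokens), lyric (T_lyric tokens)], and
    # encoder_hidden_mask concatenates a ones column for the speaker token with the text
    # and lyric masks in the same order.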
def decode(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_mask: torch.Tensor,
timestep: Optional[torch.Tensor],
output_length: int = 0,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
):
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
temb = self.t_block(embedded_timestep)
hidden_states = self.proj_in(hidden_states)
# controlnet logic
if block_controlnet_hidden_states is not None:
control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
hidden_states = hidden_states + control_condi * controlnet_scale
# inner_hidden_states = []
rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])
for index_block, block in enumerate(self.transformer_blocks):
hidden_states = block(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_hidden_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
temb=temb,
)
output = self.final_layer(hidden_states, embedded_timestep, output_length)
return output
def forward(
self,
x,
timestep,
attention_mask=None,
context: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.LongTensor] = None,
speaker_embeds: Optional[torch.FloatTensor] = None,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
lyrics_strength=1.0,
**kwargs
):
hidden_states = x
encoder_text_hidden_states = context
encoder_hidden_states, encoder_hidden_mask = self.encode(
encoder_text_hidden_states=encoder_text_hidden_states,
text_attention_mask=text_attention_mask,
speaker_embeds=speaker_embeds,
lyric_token_idx=lyric_token_idx,
lyric_mask=lyric_mask,
lyrics_strength=lyrics_strength,
)
output_length = hidden_states.shape[-1]
output = self.decode(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_mask=encoder_hidden_mask,
timestep=timestep,
output_length=output_length,
block_controlnet_hidden_states=block_controlnet_hidden_states,
controlnet_scale=controlnet_scale,
)
return output

View File

@ -0,0 +1,644 @@
# Rewritten from diffusers
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Union
import comfy.model_management
import comfy.ops
ops = comfy.ops.disable_weight_init
class RMSNorm(ops.RMSNorm):
def __init__(self, dim, eps=1e-5, elementwise_affine=True, bias=False):
super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)
if elementwise_affine:
self.bias = nn.Parameter(torch.empty(dim)) if bias else None
def forward(self, x):
x = super().forward(x)
if self.elementwise_affine:
if self.bias is not None:
x = x + comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device)
return x
def get_normalization(norm_type, num_features, num_groups=32, eps=1e-5):
if norm_type == "batch_norm":
return nn.BatchNorm2d(num_features)
elif norm_type == "group_norm":
return ops.GroupNorm(num_groups, num_features)
elif norm_type == "layer_norm":
return ops.LayerNorm(num_features)
elif norm_type == "rms_norm":
return RMSNorm(num_features, eps=eps, elementwise_affine=True, bias=True)
else:
raise ValueError(f"Unknown normalization type: {norm_type}")
def get_activation(activation_type):
if activation_type == "relu":
return nn.ReLU()
elif activation_type == "relu6":
return nn.ReLU6()
elif activation_type == "silu":
return nn.SiLU()
elif activation_type == "leaky_relu":
return nn.LeakyReLU(0.2)
else:
raise ValueError(f"Unknown activation type: {activation_type}")
class ResBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
norm_type: str = "batch_norm",
act_fn: str = "relu6",
) -> None:
super().__init__()
self.norm_type = norm_type
self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity()
self.conv1 = ops.Conv2d(in_channels, in_channels, 3, 1, 1)
self.conv2 = ops.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
self.norm = get_normalization(norm_type, out_channels)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.conv1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.norm_type == "rms_norm":
# move channel to the last dimension so we apply RMSnorm across channel dimension
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
else:
hidden_states = self.norm(hidden_states)
return hidden_states + residual
class SanaMultiscaleAttentionProjection(nn.Module):
def __init__(
self,
in_channels: int,
num_attention_heads: int,
kernel_size: int,
) -> None:
super().__init__()
channels = 3 * in_channels
self.proj_in = ops.Conv2d(
channels,
channels,
kernel_size,
padding=kernel_size // 2,
groups=channels,
bias=False,
)
self.proj_out = ops.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.proj_in(hidden_states)
hidden_states = self.proj_out(hidden_states)
return hidden_states
class SanaMultiscaleLinearAttention(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_attention_heads: int = None,
attention_head_dim: int = 8,
mult: float = 1.0,
norm_type: str = "batch_norm",
kernel_sizes: tuple = (5,),
eps: float = 1e-15,
residual_connection: bool = False,
):
super().__init__()
self.eps = eps
self.attention_head_dim = attention_head_dim
self.norm_type = norm_type
self.residual_connection = residual_connection
num_attention_heads = (
int(in_channels // attention_head_dim * mult)
if num_attention_heads is None
else num_attention_heads
)
inner_dim = num_attention_heads * attention_head_dim
self.to_q = ops.Linear(in_channels, inner_dim, bias=False)
self.to_k = ops.Linear(in_channels, inner_dim, bias=False)
self.to_v = ops.Linear(in_channels, inner_dim, bias=False)
self.to_qkv_multiscale = nn.ModuleList()
for kernel_size in kernel_sizes:
self.to_qkv_multiscale.append(
SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
)
self.nonlinearity = nn.ReLU()
self.to_out = ops.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
self.norm_out = get_normalization(norm_type, out_channels)
def apply_linear_attention(self, query, key, value):
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)
scores = torch.matmul(value, key.transpose(-1, -2))
hidden_states = torch.matmul(scores, query)
hidden_states = hidden_states.to(dtype=torch.float32)
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
return hidden_states
def apply_quadratic_attention(self, query, key, value):
scores = torch.matmul(key.transpose(-1, -2), query)
scores = scores.to(dtype=torch.float32)
scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
hidden_states = torch.matmul(value, scores.to(value.dtype))
return hidden_states
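    # Rough complexity sketch (illustrative): with L = H * W tokens and per-head dim d,
    # apply_linear_attention only forms (d+1) x d and (d+1) x L products, i.e. O(L * d^2),
    # while apply_quadratic_attention builds an L x L score matrix, i.e. O(L^2 * d).
    # forward() below takes the linear path whenever L exceeds attention_head_dim, so large
    # feature maps avoid the quadratic cost.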
def forward(self, hidden_states):
height, width = hidden_states.shape[-2:]
        use_linear_attention = height * width > self.attention_head_dim
residual = hidden_states
batch_size, _, height, width = list(hidden_states.size())
original_dtype = hidden_states.dtype
hidden_states = hidden_states.movedim(1, -1)
query = self.to_q(hidden_states)
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
hidden_states = torch.cat([query, key, value], dim=3)
hidden_states = hidden_states.movedim(-1, 1)
multi_scale_qkv = [hidden_states]
for block in self.to_qkv_multiscale:
multi_scale_qkv.append(block(hidden_states))
hidden_states = torch.cat(multi_scale_qkv, dim=1)
if use_linear_attention:
# for linear attention upcast hidden_states to float32
hidden_states = hidden_states.to(dtype=torch.float32)
hidden_states = hidden_states.reshape(batch_size, -1, 3 * self.attention_head_dim, height * width)
query, key, value = hidden_states.chunk(3, dim=2)
query = self.nonlinearity(query)
key = self.nonlinearity(key)
if use_linear_attention:
hidden_states = self.apply_linear_attention(query, key, value)
hidden_states = hidden_states.to(dtype=original_dtype)
else:
hidden_states = self.apply_quadratic_attention(query, key, value)
hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
hidden_states = self.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
if self.norm_type == "rms_norm":
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
else:
hidden_states = self.norm_out(hidden_states)
if self.residual_connection:
hidden_states = hidden_states + residual
return hidden_states
class EfficientViTBlock(nn.Module):
def __init__(
self,
in_channels: int,
mult: float = 1.0,
attention_head_dim: int = 32,
qkv_multiscales: tuple = (5,),
norm_type: str = "batch_norm",
) -> None:
super().__init__()
self.attn = SanaMultiscaleLinearAttention(
in_channels=in_channels,
out_channels=in_channels,
mult=mult,
attention_head_dim=attention_head_dim,
norm_type=norm_type,
kernel_sizes=qkv_multiscales,
residual_connection=True,
)
self.conv_out = GLUMBConv(
in_channels=in_channels,
out_channels=in_channels,
norm_type="rms_norm",
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.attn(x)
x = self.conv_out(x)
return x
class GLUMBConv(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
expand_ratio: float = 4,
norm_type: str = None,
residual_connection: bool = True,
) -> None:
super().__init__()
hidden_channels = int(expand_ratio * in_channels)
self.norm_type = norm_type
self.residual_connection = residual_connection
self.nonlinearity = nn.SiLU()
self.conv_inverted = ops.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
self.conv_depth = ops.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
self.conv_point = ops.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
self.norm = None
if norm_type == "rms_norm":
self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.residual_connection:
residual = hidden_states
hidden_states = self.conv_inverted(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv_depth(hidden_states)
hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
hidden_states = hidden_states * self.nonlinearity(gate)
hidden_states = self.conv_point(hidden_states)
if self.norm_type == "rms_norm":
# move channel to the last dimension so we apply RMSnorm across channel dimension
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
if self.residual_connection:
hidden_states = hidden_states + residual
return hidden_states
def get_block(
block_type: str,
in_channels: int,
out_channels: int,
attention_head_dim: int,
norm_type: str,
act_fn: str,
    qkv_multiscales: tuple = (),
):
if block_type == "ResBlock":
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
elif block_type == "EfficientViTBlock":
block = EfficientViTBlock(
in_channels,
attention_head_dim=attention_head_dim,
norm_type=norm_type,
            qkv_multiscales=qkv_multiscales
)
else:
raise ValueError(f"Block with {block_type=} is not supported.")
return block
class DCDownBlock2d(nn.Module):
def __init__(self, in_channels: int, out_channels: int, downsample: bool = False, shortcut: bool = True) -> None:
super().__init__()
self.downsample = downsample
self.factor = 2
self.stride = 1 if downsample else 2
self.group_size = in_channels * self.factor**2 // out_channels
self.shortcut = shortcut
out_ratio = self.factor**2
if downsample:
assert out_channels % out_ratio == 0
out_channels = out_channels // out_ratio
self.conv = ops.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=self.stride,
padding=1,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
x = self.conv(hidden_states)
if self.downsample:
x = F.pixel_unshuffle(x, self.factor)
if self.shortcut:
y = F.pixel_unshuffle(hidden_states, self.factor)
y = y.unflatten(1, (-1, self.group_size))
y = y.mean(dim=2)
hidden_states = x + y
else:
hidden_states = x
return hidden_states
class DCUpBlock2d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
interpolate: bool = False,
shortcut: bool = True,
interpolation_mode: str = "nearest",
) -> None:
super().__init__()
self.interpolate = interpolate
self.interpolation_mode = interpolation_mode
self.shortcut = shortcut
self.factor = 2
self.repeats = out_channels * self.factor**2 // in_channels
out_ratio = self.factor**2
if not interpolate:
out_channels = out_channels * out_ratio
self.conv = ops.Conv2d(in_channels, out_channels, 3, 1, 1)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.interpolate:
x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
x = self.conv(x)
else:
x = self.conv(hidden_states)
x = F.pixel_shuffle(x, self.factor)
if self.shortcut:
y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
y = F.pixel_shuffle(y, self.factor)
hidden_states = x + y
else:
hidden_states = x
return hidden_states
class Encoder(nn.Module):
def __init__(
self,
in_channels: int,
latent_channels: int,
attention_head_dim: int = 32,
        block_type: Union[str, tuple] = "ResBlock",
block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024),
layers_per_block: tuple = (2, 2, 2, 2, 2, 2),
qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)),
downsample_block_type: str = "pixel_unshuffle",
out_shortcut: bool = True,
):
super().__init__()
num_blocks = len(block_out_channels)
if isinstance(block_type, str):
block_type = (block_type,) * num_blocks
if layers_per_block[0] > 0:
self.conv_in = ops.Conv2d(
in_channels,
block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
kernel_size=3,
stride=1,
padding=1,
)
else:
self.conv_in = DCDownBlock2d(
in_channels=in_channels,
out_channels=block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
downsample=downsample_block_type == "pixel_unshuffle",
shortcut=False,
)
down_blocks = []
for i, (out_channel, num_layers) in enumerate(zip(block_out_channels, layers_per_block)):
down_block_list = []
for _ in range(num_layers):
block = get_block(
block_type[i],
out_channel,
out_channel,
attention_head_dim=attention_head_dim,
norm_type="rms_norm",
act_fn="silu",
                    qkv_multiscales=qkv_multiscales[i],
)
down_block_list.append(block)
if i < num_blocks - 1 and num_layers > 0:
downsample_block = DCDownBlock2d(
in_channels=out_channel,
out_channels=block_out_channels[i + 1],
downsample=downsample_block_type == "pixel_unshuffle",
shortcut=True,
)
down_block_list.append(downsample_block)
down_blocks.append(nn.Sequential(*down_block_list))
self.down_blocks = nn.ModuleList(down_blocks)
self.conv_out = ops.Conv2d(block_out_channels[-1], latent_channels, 3, 1, 1)
self.out_shortcut = out_shortcut
if out_shortcut:
self.out_shortcut_average_group_size = block_out_channels[-1] // latent_channels
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
if self.out_shortcut:
x = hidden_states.unflatten(1, (-1, self.out_shortcut_average_group_size))
x = x.mean(dim=2)
hidden_states = self.conv_out(hidden_states) + x
else:
hidden_states = self.conv_out(hidden_states)
return hidden_states
class Decoder(nn.Module):
def __init__(
self,
in_channels: int,
latent_channels: int,
attention_head_dim: int = 32,
        block_type: Union[str, tuple] = "ResBlock",
block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024),
layers_per_block: tuple = (2, 2, 2, 2, 2, 2),
qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)),
        norm_type: Union[str, tuple] = "rms_norm",
        act_fn: Union[str, tuple] = "silu",
upsample_block_type: str = "pixel_shuffle",
in_shortcut: bool = True,
):
super().__init__()
num_blocks = len(block_out_channels)
if isinstance(block_type, str):
block_type = (block_type,) * num_blocks
if isinstance(norm_type, str):
norm_type = (norm_type,) * num_blocks
if isinstance(act_fn, str):
act_fn = (act_fn,) * num_blocks
self.conv_in = ops.Conv2d(latent_channels, block_out_channels[-1], 3, 1, 1)
self.in_shortcut = in_shortcut
if in_shortcut:
self.in_shortcut_repeats = block_out_channels[-1] // latent_channels
up_blocks = []
for i, (out_channel, num_layers) in reversed(list(enumerate(zip(block_out_channels, layers_per_block)))):
up_block_list = []
if i < num_blocks - 1 and num_layers > 0:
upsample_block = DCUpBlock2d(
block_out_channels[i + 1],
out_channel,
interpolate=upsample_block_type == "interpolate",
shortcut=True,
)
up_block_list.append(upsample_block)
for _ in range(num_layers):
block = get_block(
block_type[i],
out_channel,
out_channel,
attention_head_dim=attention_head_dim,
norm_type=norm_type[i],
act_fn=act_fn[i],
                    qkv_multiscales=qkv_multiscales[i],
)
up_block_list.append(block)
up_blocks.insert(0, nn.Sequential(*up_block_list))
self.up_blocks = nn.ModuleList(up_blocks)
channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
self.conv_act = nn.ReLU()
self.conv_out = None
if layers_per_block[0] > 0:
self.conv_out = ops.Conv2d(channels, in_channels, 3, 1, 1)
else:
self.conv_out = DCUpBlock2d(
channels, in_channels, interpolate=upsample_block_type == "interpolate", shortcut=False
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.in_shortcut:
x = hidden_states.repeat_interleave(
self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
)
hidden_states = self.conv_in(hidden_states) + x
else:
hidden_states = self.conv_in(hidden_states)
for up_block in reversed(self.up_blocks):
hidden_states = up_block(hidden_states)
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class AutoencoderDC(nn.Module):
def __init__(
self,
in_channels: int = 2,
latent_channels: int = 8,
attention_head_dim: int = 32,
encoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"],
decoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"],
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024),
encoder_layers_per_block: Tuple[int] = (2, 2, 3, 3),
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3),
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)),
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)),
upsample_block_type: str = "interpolate",
downsample_block_type: str = "Conv",
decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
decoder_act_fns: Union[str, Tuple[str]] = "silu",
scaling_factor: float = 0.41407,
) -> None:
super().__init__()
self.encoder = Encoder(
in_channels=in_channels,
latent_channels=latent_channels,
attention_head_dim=attention_head_dim,
block_type=encoder_block_types,
block_out_channels=encoder_block_out_channels,
layers_per_block=encoder_layers_per_block,
qkv_multiscales=encoder_qkv_multiscales,
downsample_block_type=downsample_block_type,
)
self.decoder = Decoder(
in_channels=in_channels,
latent_channels=latent_channels,
attention_head_dim=attention_head_dim,
block_type=decoder_block_types,
block_out_channels=decoder_block_out_channels,
layers_per_block=decoder_layers_per_block,
qkv_multiscales=decoder_qkv_multiscales,
norm_type=decoder_norm_types,
act_fn=decoder_act_fns,
upsample_block_type=upsample_block_type,
)
self.scaling_factor = scaling_factor
self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
def encode(self, x: torch.Tensor) -> torch.Tensor:
"""Internal encoding function."""
encoded = self.encoder(x)
return encoded * self.scaling_factor
def decode(self, z: torch.Tensor) -> torch.Tensor:
# Scale the latents back
z = z / self.scaling_factor
decoded = self.decoder(z)
return decoded
def forward(self, x: torch.Tensor) -> torch.Tensor:
z = self.encode(x)
return self.decode(z)
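    # Hypothetical round trip (illustrative, values assumed): with the defaults above,
    # encode() shrinks the spatial dims by 2 ** (len(encoder_block_out_channels) - 1) = 8
    # and multiplies by scaling_factor; decode() divides by the same factor before
    # upsampling, so e.g.:
    #   ae = AutoencoderDC()
    #   recon = ae(torch.randn(1, 2, 128, 1024))     # shape (1, 2, 128, 1024) is preserved
    #   z = ae.encode(torch.randn(1, 2, 128, 1024))  # shape (1, 8, 16, 128)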

View File

@ -0,0 +1,109 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_dcae_pipeline.py
import torch
from .autoencoder_dc import AutoencoderDC
import logging
try:
import torchaudio
except:
logging.warning("torchaudio missing, ACE model will be broken")
import torchvision.transforms as transforms
from .music_vocoder import ADaMoSHiFiGANV1
class MusicDCAE(torch.nn.Module):
def __init__(self, source_sample_rate=None, dcae_config={}, vocoder_config={}):
super(MusicDCAE, self).__init__()
self.dcae = AutoencoderDC(**dcae_config)
self.vocoder = ADaMoSHiFiGANV1(**vocoder_config)
if source_sample_rate is None:
self.source_sample_rate = 48000
else:
self.source_sample_rate = source_sample_rate
# self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
self.transform = transforms.Compose([
transforms.Normalize(0.5, 0.5),
])
self.min_mel_value = -11.0
self.max_mel_value = 3.0
self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
self.mel_chunk_size = 1024
self.time_dimention_multiple = 8
self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
self.scale_factor = 0.1786
self.shift_factor = -1.9091
def load_audio(self, audio_path):
audio, sr = torchaudio.load(audio_path)
return audio, sr
def forward_mel(self, audios):
mels = []
for i in range(len(audios)):
image = self.vocoder.mel_transform(audios[i])
mels.append(image)
mels = torch.stack(mels)
return mels
@torch.no_grad()
def encode(self, audios, audio_lengths=None, sr=None):
if audio_lengths is None:
audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
audio_lengths = audio_lengths.to(audios.device)
if sr is None:
sr = self.source_sample_rate
if sr != 44100:
audios = torchaudio.functional.resample(audios, sr, 44100)
max_audio_len = audios.shape[-1]
if max_audio_len % (8 * 512) != 0:
audios = torch.nn.functional.pad(audios, (0, 8 * 512 - max_audio_len % (8 * 512)))
mels = self.forward_mel(audios)
mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
mels = self.transform(mels)
latents = []
for mel in mels:
latent = self.dcae.encoder(mel.unsqueeze(0))
latents.append(latent)
latents = torch.cat(latents, dim=0)
# latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
latents = (latents - self.shift_factor) * self.scale_factor
return latents
# return latents, latent_lengths
@torch.no_grad()
def decode(self, latents, audio_lengths=None, sr=None):
latents = latents / self.scale_factor + self.shift_factor
pred_wavs = []
for latent in latents:
mels = self.dcae.decoder(latent.unsqueeze(0))
mels = mels * 0.5 + 0.5
mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
wav = self.vocoder.decode(mels[0]).squeeze(1)
if sr is not None:
# resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
wav = torchaudio.functional.resample(wav, 44100, sr)
# wav = resampler(wav)
else:
sr = 44100
pred_wavs.append(wav)
if audio_lengths is not None:
pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
return torch.stack(pred_wavs)
# return sr, pred_wavs
    def forward(self, audios, audio_lengths=None, sr=None):
        # encode() and decode() in this port return only the latents and the stacked
        # waveforms, so forward mirrors that instead of the original
        # (sr, wavs, latents, latent_lengths) tuple.
        latents = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
        pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
        return pred_wavs, latents
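    # Hypothetical usage sketch (illustrative, configs assumed):
    #   dcae = MusicDCAE(dcae_config={...}, vocoder_config={...})
    #   latents = dcae.encode(audios, sr=48000)   # resample to 44.1 kHz, pad to a multiple
    #                                             # of 8 * 512 samples, log-mel, DCAE encode,
    #                                             # then apply the scale/shift
    #   wavs = dcae.decode(latents, sr=48000)     # invert the scale/shift, DCAE decode,
    #                                             # vocode and resample back to 48 kHz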

View File

@ -0,0 +1,113 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_log_mel.py
import torch
import torch.nn as nn
from torch import Tensor
import logging
try:
from torchaudio.transforms import MelScale
except:
logging.warning("torchaudio missing, ACE model will be broken")
import comfy.model_management
class LinearSpectrogram(nn.Module):
def __init__(
self,
n_fft=2048,
win_length=2048,
hop_length=512,
center=False,
mode="pow2_sqrt",
):
super().__init__()
self.n_fft = n_fft
self.win_length = win_length
self.hop_length = hop_length
self.center = center
self.mode = mode
self.register_buffer("window", torch.hann_window(win_length))
def forward(self, y: Tensor) -> Tensor:
if y.ndim == 3:
y = y.squeeze(1)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(
(self.win_length - self.hop_length) // 2,
(self.win_length - self.hop_length + 1) // 2,
),
mode="reflect",
).squeeze(1)
dtype = y.dtype
spec = torch.stft(
y.float(),
self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=comfy.model_management.cast_to(self.window, dtype=torch.float32, device=y.device),
center=self.center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
spec = torch.view_as_real(spec)
if self.mode == "pow2_sqrt":
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
spec = spec.to(dtype)
return spec
class LogMelSpectrogram(nn.Module):
def __init__(
self,
sample_rate=44100,
n_fft=2048,
win_length=2048,
hop_length=512,
n_mels=128,
center=False,
f_min=0.0,
f_max=None,
):
super().__init__()
self.sample_rate = sample_rate
self.n_fft = n_fft
self.win_length = win_length
self.hop_length = hop_length
self.center = center
self.n_mels = n_mels
self.f_min = f_min
self.f_max = f_max or sample_rate // 2
self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
self.mel_scale = MelScale(
self.n_mels,
self.sample_rate,
self.f_min,
self.f_max,
self.n_fft // 2 + 1,
"slaney",
"slaney",
)
def compress(self, x: Tensor) -> Tensor:
return torch.log(torch.clamp(x, min=1e-5))
def decompress(self, x: Tensor) -> Tensor:
return torch.exp(x)
def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
linear = self.spectrogram(x)
x = self.mel_scale(linear)
x = self.compress(x)
# print(x.shape)
if return_linear:
return x, self.compress(linear)
return x

View File

@ -0,0 +1,538 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_vocoder.py
import torch
from torch import nn
from functools import partial
from math import prod
from typing import Callable, Tuple, List
import numpy as np
import torch.nn.functional as F
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
from .music_log_mel import LogMelSpectrogram
import comfy.model_management
import comfy.ops
ops = comfy.ops.disable_weight_init
def drop_path(
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
""" # noqa: E501
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (
x.ndim - 1
) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
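# Illustrative example (values assumed): with drop_prob=0.25 and training=True, each sample
# in the batch is zeroed with probability 0.25 and the survivors are scaled by 1 / 0.75, so
# the expected activation matches the input:
#   x = torch.ones(8, 64, 128)
#   y = drop_path(x, drop_prob=0.25, training=True)  # rows of y are all 0 or all 1/0.75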
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
def extra_repr(self):
return f"drop_prob={round(self.drop_prob,3):0.3f}"
class LayerNorm(nn.Module):
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
with shape (batch_size, channels, height, width).
""" # noqa: E501
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError
self.normalized_shape = (normalized_shape,)
def forward(self, x):
if self.data_format == "channels_last":
return F.layer_norm(
x, self.normalized_shape, comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device), comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device), self.eps
)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = comfy.model_management.cast_to(self.weight[:, None], dtype=x.dtype, device=x.device) * x + comfy.model_management.cast_to(self.bias[:, None], dtype=x.dtype, device=x.device)
return x
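    # Illustrative note (shapes assumed): with data_format="channels_last" the input is
    # (N, ..., C) and the built-in F.layer_norm path is used (ConvNeXtBlock permutes to this
    # layout); with "channels_first" the input is (N, C, ...) and the statistics are computed
    # manually over dim 1, which is how ConvNeXtEncoder normalizes between its 1D convolutions.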
class ConvNeXtBlock(nn.Module):
r"""ConvNeXt Block. There are two equivalent implementations:
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
We use (2) as we find it slightly faster in PyTorch
Args:
dim (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
kernel_size (int): Kernel size for depthwise conv. Default: 7.
dilation (int): Dilation for depthwise conv. Default: 1.
""" # noqa: E501
def __init__(
self,
dim: int,
drop_path: float = 0.0,
layer_scale_init_value: float = 1e-6,
mlp_ratio: float = 4.0,
kernel_size: int = 7,
dilation: int = 1,
):
super().__init__()
self.dwconv = ops.Conv1d(
dim,
dim,
kernel_size=kernel_size,
padding=int(dilation * (kernel_size - 1) / 2),
groups=dim,
) # depthwise conv
self.norm = LayerNorm(dim, eps=1e-6)
self.pwconv1 = ops.Linear(
dim, int(mlp_ratio * dim)
) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = ops.Linear(int(mlp_ratio * dim), dim)
self.gamma = (
nn.Parameter(torch.empty((dim)), requires_grad=False)
if layer_scale_init_value > 0
else None
)
self.drop_path = DropPath(
drop_path) if drop_path > 0.0 else nn.Identity()
def forward(self, x, apply_residual: bool = True):
input = x
x = self.dwconv(x)
x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) * x
x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
x = self.drop_path(x)
if apply_residual:
x = input + x
return x
class ParallelConvNeXtBlock(nn.Module):
def __init__(self, kernel_sizes: List[int], *args, **kwargs):
super().__init__()
self.blocks = nn.ModuleList(
[
ConvNeXtBlock(kernel_size=kernel_size, *args, **kwargs)
for kernel_size in kernel_sizes
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.stack(
[block(x, apply_residual=False) for block in self.blocks] + [x],
dim=1,
).sum(dim=1)
class ConvNeXtEncoder(nn.Module):
def __init__(
self,
input_channels=3,
depths=[3, 3, 9, 3],
dims=[96, 192, 384, 768],
drop_path_rate=0.0,
layer_scale_init_value=1e-6,
kernel_sizes: Tuple[int] = (7,),
):
super().__init__()
assert len(depths) == len(dims)
self.channel_layers = nn.ModuleList()
stem = nn.Sequential(
ops.Conv1d(
input_channels,
dims[0],
kernel_size=7,
padding=3,
padding_mode="replicate",
),
LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
)
self.channel_layers.append(stem)
for i in range(len(depths) - 1):
mid_layer = nn.Sequential(
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
ops.Conv1d(dims[i], dims[i + 1], kernel_size=1),
)
self.channel_layers.append(mid_layer)
block_fn = (
partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
if len(kernel_sizes) == 1
else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
)
self.stages = nn.ModuleList()
drop_path_rates = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
]
cur = 0
for i in range(len(depths)):
stage = nn.Sequential(
*[
block_fn(
dim=dims[i],
drop_path=drop_path_rates[cur + j],
layer_scale_init_value=layer_scale_init_value,
)
for j in range(depths[i])
]
)
self.stages.append(stage)
cur += depths[i]
self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
for channel_layer, stage in zip(self.channel_layers, self.stages):
x = channel_layer(x)
x = stage(x)
return self.norm(x)
def get_padding(kernel_size, dilation=1):
return (kernel_size * dilation - dilation) // 2
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super().__init__()
self.convs1 = nn.ModuleList(
[
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs2 = nn.ModuleList(
[
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
]
)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.silu(x)
xt = c1(xt)
xt = F.silu(xt)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for conv in self.convs1:
remove_weight_norm(conv)
for conv in self.convs2:
remove_weight_norm(conv)
class HiFiGANGenerator(nn.Module):
def __init__(
self,
*,
hop_length: int = 512,
upsample_rates: Tuple[int] = (8, 8, 2, 2, 2),
upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2),
resblock_kernel_sizes: Tuple[int] = (3, 7, 11),
resblock_dilation_sizes: Tuple[Tuple[int]] = (
(1, 3, 5), (1, 3, 5), (1, 3, 5)),
num_mels: int = 128,
upsample_initial_channel: int = 512,
use_template: bool = True,
pre_conv_kernel_size: int = 7,
post_conv_kernel_size: int = 7,
post_activation: Callable = partial(nn.SiLU, inplace=True),
):
super().__init__()
assert (
prod(upsample_rates) == hop_length
), f"hop_length must be {prod(upsample_rates)}"
self.conv_pre = torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
num_mels,
upsample_initial_channel,
pre_conv_kernel_size,
1,
padding=get_padding(pre_conv_kernel_size),
)
)
self.num_upsamples = len(upsample_rates)
self.num_kernels = len(resblock_kernel_sizes)
self.noise_convs = nn.ModuleList()
self.use_template = use_template
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
c_cur = upsample_initial_channel // (2 ** (i + 1))
self.ups.append(
torch.nn.utils.parametrizations.weight_norm(
ops.ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
if not use_template:
continue
if i + 1 < len(upsample_rates):
stride_f0 = np.prod(upsample_rates[i + 1:])
self.noise_convs.append(
ops.Conv1d(
1,
c_cur,
kernel_size=stride_f0 * 2,
stride=stride_f0,
padding=stride_f0 // 2,
)
)
else:
self.noise_convs.append(ops.Conv1d(1, c_cur, kernel_size=1))
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
self.resblocks.append(ResBlock1(ch, k, d))
self.activation_post = post_activation()
self.conv_post = torch.nn.utils.parametrizations.weight_norm(
ops.Conv1d(
ch,
1,
post_conv_kernel_size,
1,
padding=get_padding(post_conv_kernel_size),
)
)
def forward(self, x, template=None):
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.silu(x, inplace=True)
x = self.ups[i](x)
if self.use_template:
x = x + self.noise_convs[i](template)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = self.activation_post(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
for up in self.ups:
remove_weight_norm(up)
for block in self.resblocks:
block.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
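    # Illustrative note (defaults assumed): prod(upsample_rates) must equal hop_length
    # (asserted in __init__), so every input mel frame expands to hop_length output samples;
    # with the ADaMoSHiFiGANV1 defaults below (4 * 4 * 2 * 2 * 2 * 2 * 2 = 512,
    # hop_length = 512) a mel of T frames decodes to roughly T * 512 audio samples.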
class ADaMoSHiFiGANV1(nn.Module):
def __init__(
self,
input_channels: int = 128,
depths: List[int] = [3, 3, 9, 3],
dims: List[int] = [128, 256, 384, 512],
drop_path_rate: float = 0.0,
kernel_sizes: Tuple[int] = (7,),
upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2),
upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4),
resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13),
resblock_dilation_sizes: Tuple[Tuple[int]] = (
(1, 3, 5), (1, 3, 5), (1, 3, 5), (1, 3, 5)),
num_mels: int = 512,
upsample_initial_channel: int = 1024,
use_template: bool = False,
pre_conv_kernel_size: int = 13,
post_conv_kernel_size: int = 13,
sampling_rate: int = 44100,
n_fft: int = 2048,
win_length: int = 2048,
hop_length: int = 512,
f_min: int = 40,
f_max: int = 16000,
n_mels: int = 128,
):
super().__init__()
self.backbone = ConvNeXtEncoder(
input_channels=input_channels,
depths=depths,
dims=dims,
drop_path_rate=drop_path_rate,
kernel_sizes=kernel_sizes,
)
self.head = HiFiGANGenerator(
hop_length=hop_length,
upsample_rates=upsample_rates,
upsample_kernel_sizes=upsample_kernel_sizes,
resblock_kernel_sizes=resblock_kernel_sizes,
resblock_dilation_sizes=resblock_dilation_sizes,
num_mels=num_mels,
upsample_initial_channel=upsample_initial_channel,
use_template=use_template,
pre_conv_kernel_size=pre_conv_kernel_size,
post_conv_kernel_size=post_conv_kernel_size,
)
self.sampling_rate = sampling_rate
self.mel_transform = LogMelSpectrogram(
sample_rate=sampling_rate,
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
f_min=f_min,
f_max=f_max,
n_mels=n_mels,
)
self.eval()
@torch.no_grad()
def decode(self, mel):
y = self.backbone(mel)
y = self.head(y)
return y
@torch.no_grad()
def encode(self, x):
return self.mel_transform(x)
def forward(self, mel):
y = self.backbone(mel)
y = self.head(y)
return y

View File

@ -75,16 +75,10 @@ class SnakeBeta(nn.Module):
return x
def WNConv1d(*args, **kwargs):
try:
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
except:
return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
try:
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
except:
return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
if activation == "elu":

183
comfy/ldm/chroma/layers.py Normal file
View File

@ -0,0 +1,183 @@
import torch
from torch import Tensor, nn
from comfy.ldm.flux.math import attention
from comfy.ldm.flux.layers import (
MLPEmbedder,
RMSNorm,
QKNorm,
SelfAttention,
ModulationOut,
)
class ChromaModulationOut(ModulationOut):
@classmethod
def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
return cls(
shift=tensor[:, offset : offset + 1, :],
scale=tensor[:, offset + 1 : offset + 2, :],
gate=tensor[:, offset + 2 : offset + 3, :],
)
class Approximator(nn.Module):
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers = 5, dtype=None, device=None, operations=None):
super().__init__()
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
        self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for _ in range(n_layers)])
        self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for _ in range(n_layers)])
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
@property
def device(self):
# Get the device of the module (assumes all parameters are on the same device)
return next(self.parameters()).device
def forward(self, x: Tensor) -> Tensor:
x = self.in_proj(x)
for layer, norms in zip(self.layers, self.norms):
x = x + layer(norms(x))
x = self.out_proj(x)
return x
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
        # calculate the img blocks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
        # calculate the txt blocks
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
return img, txt
class SingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float = None,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
# proj and mlp_out
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.hidden_size = hidden_size
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.mlp_act = nn.GELU(approximate="tanh")
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
mod = vec
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += mod.gate * output
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
shift, scale = vec
shift = shift.squeeze(1)
scale = scale.squeeze(1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x

271
comfy/ldm/chroma/model.py Normal file
View File

@ -0,0 +1,271 @@
#Original code can be found on: https://github.com/black-forest-labs/flux
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.ldm.common_dit
from comfy.ldm.flux.layers import (
EmbedND,
timestep_embedding,
)
from .layers import (
DoubleStreamBlock,
LastLayer,
SingleStreamBlock,
Approximator,
ChromaModulationOut,
)
@dataclass
class ChromaParams:
in_channels: int
out_channels: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list
theta: int
patch_size: int
qkv_bias: bool
in_dim: int
out_dim: int
hidden_dim: int
n_layers: int
class Chroma(nn.Module):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
params = ChromaParams(**kwargs)
self.params = params
self.patch_size = params.patch_size
self.in_channels = params.in_channels
self.out_channels = params.out_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.in_dim = params.in_dim
self.out_dim = params.out_dim
self.hidden_dim = params.hidden_dim
self.n_layers = params.n_layers
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
        # approximator that maps timestep/guidance embeddings to all per-block modulation vectors
self.distilled_guidance_layer = Approximator(
in_dim=self.in_dim,
hidden_dim=self.hidden_dim,
out_dim=self.out_dim,
n_layers=self.n_layers,
dtype=dtype, device=device, operations=operations
)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)
if final_layer:
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
self.skip_mmdit = []
self.skip_dit = []
self.lite = False
def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0):
# This function slices up the modulations tensor which has the following layout:
# single : num_single_blocks * 3 elements
# double_img : num_double_blocks * 6 elements
# double_txt : num_double_blocks * 6 elements
# final : 2 elements
if block_type == "final":
return (tensor[:, -2:-1, :], tensor[:, -1:, :])
single_block_count = self.params.depth_single_blocks
double_block_count = self.params.depth
offset = 3 * idx
if block_type == "single":
return ChromaModulationOut.from_offset(tensor, offset)
# Double block modulations are 6 elements so we double 3 * idx.
offset *= 2
if block_type in {"double_img", "double_txt"}:
# Advance past the single block modulations.
offset += 3 * single_block_count
if block_type == "double_txt":
# Advance past the double block img modulations.
offset += 6 * double_block_count
return (
ChromaModulationOut.from_offset(tensor, offset),
ChromaModulationOut.from_offset(tensor, offset + 3),
)
raise ValueError("Bad block_type")
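    # Worked offset example (illustrative, following the layout described above): with
    # S = depth_single_blocks and D = depth, the modulation rows are
    #   [0, 3*S)                -> single-block shift/scale/gate triples
    #   [3*S, 3*S + 6*D)        -> double-block img modulations (two triples per block)
    #   [3*S + 6*D, 3*S + 12*D) -> double-block txt modulations
    #   last two rows           -> final-layer shift/scale
    # so get_modulations(t, "double_txt", idx=1) starts reading at row 3*S + 6*D + 6.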
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
guidance: Tensor = None,
control = None,
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
# distilled vector guidance
mod_index_length = 344
distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(img.device, img.dtype)
# guidance = guidance *
distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype)
# get all modulation index
modulation_index = timestep_embedding(torch.arange(mod_index_length), 32).to(img.device, img.dtype)
        # broadcast the modulation index so every batch element gets the full set of indices
modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
# and we need to broadcast timestep and guidance along too
timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1).to(img.dtype).to(img.device, img.dtype)
        # only then can the two be concatenated together
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1).to(img.device, img.dtype)
mod_vectors = self.distilled_guidance_layer(input_vec)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
if i not in self.skip_mmdit:
double_mod = (
self.get_modulations(mod_vectors, "double_img", idx=i),
self.get_modulations(mod_vectors, "double_txt", idx=i),
)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": double_mod,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img,
txt=txt,
vec=double_mod,
pe=pe,
attn_mask=attn_mask)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
if i not in self.skip_dit:
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": single_mod,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, txt.shape[1] :, ...] += add
img = img[:, txt.shape[1] :, ...]
final_mod = self.get_modulations(mod_vectors, "final")
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
return img
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
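For reference, a minimal self-contained sketch (hypothetical shapes; plain zero padding stands in for comfy.ldm.common_dit.pad_to_patch_size) of the 2x2 patchify / unpatchify round trip that forward() wraps around forward_orig():

import torch
import torch.nn.functional as F
from einops import rearrange

bs, c, h, w, patch_size = 1, 16, 37, 51, 2  # illustrative latent shape
x = torch.randn(bs, c, h, w)
pad_h = (patch_size - h % patch_size) % patch_size
pad_w = (patch_size - w % patch_size) % patch_size
x = F.pad(x, (0, pad_w, 0, pad_h))  # stand-in for pad_to_patch_size
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = (h + (patch_size // 2)) // patch_size
w_len = (w + (patch_size // 2)) // patch_size
assert img.shape == (bs, h_len * w_len, c * patch_size * patch_size)
out = rearrange(img, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w]
assert out.shape == (bs, c, h, w)  # cropped back to the original size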

View File

@ -23,7 +23,6 @@ from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import nn
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
from comfy.ldm.modules.attention import optimized_attention
@ -37,11 +36,11 @@ def apply_rotary_pos_emb(
return t_out
def get_normalization(name: str, channels: int, weight_args={}):
def get_normalization(name: str, channels: int, weight_args={}, operations=None):
if name == "I":
return nn.Identity()
elif name == "R":
return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
else:
raise ValueError(f"Normalization {name} not found")
@ -120,15 +119,15 @@ class Attention(nn.Module):
self.to_q = nn.Sequential(
operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
get_normalization(qkv_norm[0], norm_dim),
get_normalization(qkv_norm[0], norm_dim, weight_args=weight_args, operations=operations),
)
self.to_k = nn.Sequential(
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
get_normalization(qkv_norm[1], norm_dim),
get_normalization(qkv_norm[1], norm_dim, weight_args=weight_args, operations=operations),
)
self.to_v = nn.Sequential(
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
get_normalization(qkv_norm[2], norm_dim),
get_normalization(qkv_norm[2], norm_dim, weight_args=weight_args, operations=operations),
)
self.to_out = nn.Sequential(

View File

@ -27,8 +27,6 @@ from torchvision import transforms
from enum import Enum
import logging
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
from .blocks import (
FinalLayer,
GeneralDITTransformerBlock,
@ -195,7 +193,7 @@ class GeneralDIT(nn.Module):
if self.affline_emb_norm:
logging.debug("Building affine embedding normalization layer")
self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6)
self.affline_norm = operations.RMSNorm(model_channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
else:
self.affline_norm = nn.Identity()

View File

@ -13,7 +13,6 @@ from comfy.ldm.modules.attention import optimized_attention
from .layers import (
FeedForward,
PatchEmbed,
RMSNorm,
TimestepEmbedder,
)
@ -90,10 +89,10 @@ class AsymmetricAttention(nn.Module):
# Query and key normalization for stability.
assert qk_norm
self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
# Output layers. y features go back down from dim_x -> dim_y.
self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)

View File

@ -151,14 +151,3 @@ class PatchEmbed(nn.Module):
x = self.norm(x)
return x
class RMSNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
self.register_parameter("bias", None)
def forward(self, x):
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)

View File

@ -699,10 +699,13 @@ class HiDreamImageTransformer2DModel(nn.Module):
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
encoder_hidden_states_llama3=None,
image_cond=None,
control = None,
transformer_options = {},
) -> torch.Tensor:
bs, c, h, w = x.shape
if image_cond is not None:
x = torch.cat([x, image_cond], dim=-1)
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
timesteps = t
pooled_embeds = y

View File

@ -3,7 +3,7 @@ import torch
import torch.nn as nn
import comfy.ops
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from torch.utils import checkpoint
@ -51,7 +51,7 @@ class HunYuanDiTBlock(nn.Module):
if norm_type == "layer":
norm_layer = operations.LayerNorm
elif norm_type == "rms":
norm_layer = RMSNorm
norm_layer = operations.RMSNorm
else:
raise ValueError(f"Unknown norm_type: {norm_type}")

View File

@ -1,7 +1,6 @@
import torch
from torch import nn
import comfy.ldm.modules.attention
from comfy.ldm.genmo.joint_model.layers import RMSNorm
import comfy.ldm.common_dit
from einops import rearrange
import math
@ -262,8 +261,8 @@ class CrossAttention(nn.Module):
self.heads = heads
self.dim_head = dim_head
self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)

View File

@ -8,7 +8,7 @@ import torch.nn as nn
import torch.nn.functional as F
import comfy.ldm.common_dit
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
@ -64,8 +64,8 @@ class JointAttention(nn.Module):
)
if qk_norm:
self.q_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
self.k_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
self.q_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.k_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
else:
self.q_norm = self.k_norm = nn.Identity()
@ -242,11 +242,11 @@ class JointTransformerBlock(nn.Module):
operation_settings=operation_settings,
)
self.layer_id = layer_id
self.attention_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
self.attention_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.ffn_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.attention_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
self.attention_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.ffn_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.modulation = modulation
if modulation:
@ -431,7 +431,7 @@ class NextDiT(nn.Module):
self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
self.cap_embedder = nn.Sequential(
RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, **operation_settings),
operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").Linear(
cap_feat_dim,
dim,
@ -457,7 +457,7 @@ class NextDiT(nn.Module):
for layer_id in range(n_layers)
]
)
self.norm_final = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
assert (dim // n_heads) == sum(axes_dims)

View File

@ -9,7 +9,6 @@ from einops import repeat
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
import comfy.ldm.common_dit
import comfy.model_management
@ -49,8 +48,8 @@ class WanSelfAttention(nn.Module):
self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.norm_q = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
self.norm_k = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, freqs):
r"""
@ -114,7 +113,7 @@ class WanI2VCrossAttention(WanSelfAttention):
self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
# self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, context, context_img_len):
r"""
@ -631,6 +630,7 @@ class VaceWanModel(WanModel):
if ii is not None:
c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x += c_skip * vace_strength
del c_skip
# head
x = self.head(x, e)

View File

@ -279,6 +279,13 @@ def model_lora_keys_unet(model, key_map={}):
key_map["transformer.{}".format(key_lora)] = k
key_map["diffusion_model.{}".format(key_lora)] = k # Old loras
if isinstance(model, comfy.model_base.HiDream):
for k in sdk:
if k.startswith("diffusion_model."):
if k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format
return key_map

View File

@ -38,6 +38,8 @@ import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.ace.model
import comfy.model_management
import comfy.patcher_extension
@ -786,8 +788,8 @@ class PixArt(BaseModel):
return out
class Flux(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.flux.model.Flux):
super().__init__(model_config, model_type, device=device, unet_model=unet_model)
def concat_cond(self, **kwargs):
try:
@ -1104,4 +1106,38 @@ class HiDream(BaseModel):
conditioning_llama3 = kwargs.get("conditioning_llama3", None)
if conditioning_llama3 is not None:
out['encoder_hidden_states_llama3'] = comfy.conds.CONDRegular(conditioning_llama3)
image_cond = kwargs.get("concat_latent_image", None)
if image_cond is not None:
out['image_cond'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_cond))
return out
class Chroma(Flux):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
guidance = kwargs.get("guidance", 0)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class ACEStep(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.model.ACEStepTransformer2DModel)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
noise = kwargs.get("noise", None)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
if cross_attn is not None:
out['lyric_token_idx'] = comfy.conds.CONDRegular(conditioning_lyrics)
out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
return out

View File

@ -164,7 +164,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if in_key in state_dict_keys:
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
dit_config["out_channels"] = 16
dit_config["vec_in_dim"] = 768
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
if vec_in_key in state_dict_keys:
dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072
dit_config["mlp_ratio"] = 4.0
@ -174,7 +176,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 10000
dit_config["qkv_bias"] = True
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
dit_config["image_model"] = "chroma"
dit_config["in_channels"] = 64
dit_config["out_channels"] = 64
dit_config["in_dim"] = 64
dit_config["out_dim"] = 3072
dit_config["hidden_dim"] = 5120
dit_config["n_layers"] = 5
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
@ -211,10 +222,39 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
dit_config = {}
dit_config["image_model"] = "ltxv"
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
dit_config["attention_head_dim"] = shape[0] // 32
dit_config["cross_attention_dim"] = shape[1]
if metadata is not None and "config" in metadata:
dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
return dit_config
if '{}genre_embedder.weight'.format(key_prefix) in state_dict_keys: #ACE-Step model
dit_config = {}
dit_config["audio_model"] = "ace"
dit_config["attention_head_dim"] = 128
dit_config["in_channels"] = 8
dit_config["inner_dim"] = 2560
dit_config["max_height"] = 16
dit_config["max_position"] = 32768
dit_config["max_width"] = 32768
dit_config["mlp_ratio"] = 2.5
dit_config["num_attention_heads"] = 20
dit_config["num_layers"] = 24
dit_config["out_channels"] = 8
dit_config["patch_size"] = [16, 1]
dit_config["rope_theta"] = 1000000.0
dit_config["speaker_embedding_dim"] = 512
dit_config["text_embedding_dim"] = 768
dit_config["ssl_encoder_depths"] = [8, 8]
dit_config["ssl_latent_dims"] = [1024, 768]
dit_config["ssl_names"] = ["mert", "m-hubert"]
dit_config["lyric_encoder_vocab_size"] = 6693
dit_config["lyric_hidden_size"] = 1024
return dit_config
if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
patch_size = 2
dit_config = {}

View File

@ -967,15 +967,61 @@ def force_channels_last():
#TODO
return False
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
STREAMS = {}
NUM_STREAMS = 1
if args.async_offload:
NUM_STREAMS = 2
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
stream_counters = {}
def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0)
if NUM_STREAMS <= 1:
return None
if device in STREAMS:
ss = STREAMS[device]
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
if is_device_cuda(device):
ss[stream_counter].wait_stream(torch.cuda.current_stream())
stream_counters[device] = stream_counter
return s
elif is_device_cuda(device):
ss = []
for k in range(NUM_STREAMS):
ss.append(torch.cuda.Stream(device=device, priority=0))
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
return None
def sync_stream(device, stream):
if stream is None:
return
if is_device_cuda(device):
torch.cuda.current_stream().wait_stream(stream)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
if stream is not None:
with stream:
return weight.to(dtype=dtype, copy=copy)
return weight.to(dtype=dtype, copy=copy)
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
if stream is not None:
with stream:
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
else:
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
return r
def cast_to_device(tensor, device, dtype, copy=False):
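A standalone sketch of the side-stream offload pattern these helpers implement (illustration only, assuming a CUDA device; not part of the ComfyUI API): the copy is issued on a secondary stream, and the compute stream waits on it before using the result.

import torch

def copy_on_side_stream(weight, device):
    # mirrors the get_offload_stream / cast_to / sync_stream flow above (simplified)
    if not torch.cuda.is_available():
        return weight.to(device)
    side = torch.cuda.Stream(device=device)
    side.wait_stream(torch.cuda.current_stream())  # the copy must see prior writes to weight
    with torch.cuda.stream(side):
        r = torch.empty_like(weight, device=device)
        r.copy_(weight, non_blocking=True)
    torch.cuda.current_stream().wait_stream(side)  # compute waits for the copy to finish
    return r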

View File

@ -111,13 +111,14 @@ class ModelSamplingDiscrete(torch.nn.Module):
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
self.zsnr = zsnr
# self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
if zsnr:
if self.zsnr:
sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
self.set_sigmas(sigmas)

View File

@ -22,6 +22,7 @@ import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
import contextlib
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
@ -37,20 +38,31 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if device is None:
device = input.device
offload_stream = comfy.model_management.get_offload_stream(device)
if offload_stream is not None:
wf_context = offload_stream
else:
wf_context = contextlib.nullcontext()
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None:
has_function = len(s.bias_function) > 0
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
if has_function:
for f in s.bias_function:
bias = f(bias)
with wf_context:
for f in s.bias_function:
bias = f(bias)
has_function = len(s.weight_function) > 0
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
if has_function:
for f in s.weight_function:
weight = f(weight)
with wf_context:
for f in s.weight_function:
weight = f(weight)
comfy.model_management.sync_stream(device, offload_stream)
return weight, bias
class CastWeightBiasOp:
@ -296,10 +308,10 @@ def fp8_linear(self, input):
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
input = input.reshape(-1, input_shape[2]).to(dtype)
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
else:
scale_input = scale_input.to(input.device)
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
if bias is not None:
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)

View File

@ -903,7 +903,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
"gradient_estimation", "er_sde", "seeds_2", "seeds_3"]
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3"]
class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):

View File

@ -15,6 +15,7 @@ import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import yaml
import math
@ -42,6 +43,7 @@ import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.text_encoders.ace
import comfy.model_patcher
import comfy.lora
@ -120,6 +122,7 @@ class CLIP:
self.layer_idx = None
self.use_clip_schedule = False
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
self.tokenizer_options = {}
def clone(self):
n = CLIP(no_init=True)
@ -127,6 +130,7 @@ class CLIP:
n.cond_stage_model = self.cond_stage_model
n.tokenizer = self.tokenizer
n.layer_idx = self.layer_idx
n.tokenizer_options = self.tokenizer_options.copy()
n.use_clip_schedule = self.use_clip_schedule
n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n
@ -134,10 +138,18 @@ class CLIP:
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model)
def set_tokenizer_option(self, option_name, value):
self.tokenizer_options[option_name] = value
def clip_layer(self, layer_idx):
self.layer_idx = layer_idx
def tokenize(self, text, return_word_ids=False, **kwargs):
tokenizer_options = kwargs.get("tokenizer_options", {})
if len(self.tokenizer_options) > 0:
tokenizer_options = {**self.tokenizer_options, **tokenizer_options}
if len(tokenizer_options) > 0:
kwargs["tokenizer_options"] = tokenizer_options
return self.tokenizer.tokenize_with_weights(text, return_word_ids, **kwargs)
def add_hooks_to_dict(self, pooled_dict: dict[str]):
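A hypothetical usage sketch of the new per-clip tokenizer options (the option name pattern "<embedding_key>_<option>" comes from SDTokenizer further down; "t5xxl" as an embedding key is an assumption):

# `clip` is a comfy.sd.CLIP instance obtained from a loader elsewhere (assumption)
clip.set_tokenizer_option("t5xxl_min_padding", 8)  # pad the t5xxl prompt with at least 8 pad tokens
tokens = clip.tokenize("a photo of a cat")         # merged options reach tokenize_with_weights via kwargs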
@ -270,6 +282,7 @@ class VAE:
self.downscale_index_formula = None
self.upscale_index_formula = None
self.extra_1d_channel = None
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
@ -427,6 +440,20 @@ class VAE:
ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
self.latent_channels = 8
self.output_channels = 2
self.upscale_ratio = 4096
self.downscale_ratio = 4096
self.latent_dim = 2
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.disable_offload = True
self.extra_1d_channel = 16
else:
logging.warning("WARNING: No VAE weights detected, VAE not initialized.")
self.first_stage_model = None
@ -485,7 +512,13 @@ class VAE:
return output
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
if samples.ndim == 3:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
else:
og_shape = samples.shape
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
@ -505,9 +538,24 @@ class VAE:
samples /= 3.0
return samples
def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
if self.latent_dim == 1:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
out_channels = self.latent_channels
upscale_amount = 1 / self.downscale_ratio
else:
extra_channel_size = self.extra_1d_channel
out_channels = self.latent_channels * extra_channel_size
tile_x = tile_x // extra_channel_size
overlap = overlap // extra_channel_size
upscale_amount = 1 / self.downscale_ratio
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
if self.latent_dim == 1:
return out
else:
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
@ -532,7 +580,7 @@ class VAE:
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
dims = samples_in.ndim - 2
if dims == 1:
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
elif dims == 2:
pixel_samples = self.decode_tiled_(samples_in)
@ -599,7 +647,7 @@ class VAE:
tile = 256
overlap = tile // 4
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
elif self.latent_dim == 1:
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
samples = self.encode_tiled_1d(pixel_samples)
else:
samples = self.encode_tiled_(pixel_samples)
@ -704,6 +752,8 @@ class CLIPType(Enum):
LUMINA2 = 12
WAN = 13
HIDREAM = 14
CHROMA = 15
ACE = 16
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@ -808,7 +858,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.LTXV:
clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
elif clip_type == CLIPType.PIXART:
elif clip_type == CLIPType.PIXART or clip_type == CLIPType.CHROMA:
clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer
elif clip_type == CLIPType.WAN:
@ -829,8 +879,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
elif te_model == TEModel.T5_BASE:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
if clip_type == CLIPType.ACE or "spiece_model" in clip_data[0]:
clip_target.clip = comfy.text_encoders.ace.AceT5Model
clip_target.tokenizer = comfy.text_encoders.ace.AceT5Tokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
else:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
elif te_model == TEModel.GEMMA_2_2B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer

View File

@ -457,13 +457,14 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
return embed_out
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}, tokenizer_args={}):
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
self.min_length = min_length
self.end_token = None
self.min_padding = min_padding
empty = self.tokenizer('')["input_ids"]
self.tokenizer_adds_end_token = has_end_token
@ -518,13 +519,15 @@ class SDTokenizer:
return (embed, leftover)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
'''
Takes a prompt and converts it to a list of (token, weight, word id) elements.
Tokens can be both integer tokens and pre-computed CLIP tensors.
Word id values are unique per word and embedding, where the id 0 is reserved for non-word tokens.
The returned list has the dimensions NxM where M is the input size of CLIP.
'''
min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length)
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
text = escape_important(text)
parsed_weights = token_weights(text, 1.0)
@ -603,10 +606,12 @@ class SDTokenizer:
#fill last batch
if self.end_token is not None:
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
if min_padding is not None:
batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
if self.pad_to_max_length and len(batch) < self.max_length:
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
if self.min_length is not None and len(batch) < self.min_length:
batch.extend([(self.pad_token, 1.0, 0)] * (self.min_length - len(batch)))
if min_length is not None and len(batch) < min_length:
batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
if not return_word_ids:
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
@ -634,7 +639,7 @@ class SD1Tokenizer:
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids, **kwargs)
return out
def untokenize(self, token_weight_pair):

View File

@ -28,8 +28,8 @@ class SDXLTokenizer:
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
def untokenize(self, token_weight_pair):

View File

@ -17,6 +17,7 @@ import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.wan
import comfy.text_encoders.ace
from . import supported_models_base
from . import latent_formats
@ -785,6 +786,10 @@ class LTXV(supported_models_base.BASE):
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = (unet_config.get("cross_attention_dim", 2048) / 2048) * 5.5
def get_model(self, state_dict, prefix="", device=None):
out = model_base.LTXV(self, device=device)
return out
@ -993,6 +998,10 @@ class WAN21_Vace(WAN21_T2V):
"model_type": "vace",
}
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = 1.2 * self.memory_usage_factor
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
return out
@ -1064,7 +1073,62 @@ class HiDream(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return None # TODO
class Chroma(supported_models_base.BASE):
unet_config = {
"image_model": "chroma",
}
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream]
unet_extra_config = {
}
sampling_settings = {
"multiplier": 1.0,
}
latent_format = comfy.latent_formats.Flux
memory_usage_factor = 3.2
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Chroma(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
class ACEStep(supported_models_base.BASE):
unet_config = {
"audio_model": "ace",
}
unet_extra_config = {
}
sampling_settings = {
"shift": 3.0,
}
latent_format = comfy.latent_formats.ACEAudio
memory_usage_factor = 0.5
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.ACEStep(self, device=device)
return out
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
models += [SVD_img2vid]

153 comfy/text_encoders/ace.py Normal file
View File

@ -0,0 +1,153 @@
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.t5
import os
import re
import torch
import logging
from tokenizers import Tokenizer
from .ace_text_cleaners import multilingual_cleaners, japanese_to_romaji
SUPPORT_LANGUAGES = {
"en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,
"pt": 286, "pl": 294, "tr": 295, "ru": 267, "cs": 293,
"nl": 297, "ar": 5022, "zh": 5023, "ja": 5412, "hu": 5753,
"ko": 6152, "hi": 6680
}
structure_pattern = re.compile(r"\[.*?\]")
DEFAULT_VOCAB_FILE = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
class VoiceBpeTokenizer:
def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
self.tokenizer = None
if vocab_file is not None:
self.tokenizer = Tokenizer.from_file(vocab_file)
def preprocess_text(self, txt, lang):
txt = multilingual_cleaners(txt, lang)
return txt
def encode(self, txt, lang='en'):
# lang = lang.split("-")[0] # remove the region
# self.check_input_length(txt, lang)
txt = self.preprocess_text(txt, lang)
lang = "zh-cn" if lang == "zh" else lang
txt = f"[{lang}]{txt}"
txt = txt.replace(" ", "[SPACE]")
return self.tokenizer.encode(txt).ids
def get_lang(self, line):
if line.startswith("[") and line[3:4] == ']':
lang = line[1:3].lower()
if lang in SUPPORT_LANGUAGES:
return lang, line[4:]
return "en", line
def __call__(self, string):
lines = string.split("\n")
lyric_token_idx = [261]
for line in lines:
line = line.strip()
if not line:
lyric_token_idx += [2]
continue
lang, line = self.get_lang(line)
if lang not in SUPPORT_LANGUAGES:
lang = "en"
if "zh" in lang:
lang = "zh"
if "spa" in lang:
lang = "es"
try:
line_out = japanese_to_romaji(line)
if line_out != line:
lang = "ja"
line = line_out
except:
pass
try:
if structure_pattern.match(line):
token_idx = self.encode(line, "en")
else:
token_idx = self.encode(line, lang)
lyric_token_idx = lyric_token_idx + token_idx + [2]
except Exception as e:
logging.warning("tokenize error {} for line {} major_language {}".format(e, line, lang))
return {"input_ids": lyric_token_idx}
@staticmethod
def from_pretrained(path, **kwargs):
return VoiceBpeTokenizer(path, **kwargs)
def get_vocab(self):
return {}
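A small usage sketch of the lyric language-prefix convention handled above (illustrative strings; the tokenizer loads the bundled vocab.json by default):

tok = VoiceBpeTokenizer()
lang, rest = tok.get_lang("[ja]こんにちは")  # -> ("ja", "こんにちは")
lang, rest = tok.get_lang("chorus line")     # no prefix -> ("en", "chorus line")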
class UMT5BaseModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "umt5_config_base.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=False, model_options=model_options)
class UMT5BaseTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=768, embedding_key='umt5base', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=0, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class LyricsTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='lyrics', tokenizer_class=VoiceBpeTokenizer, has_start_token=True, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=2, has_end_token=False, tokenizer_data=tokenizer_data)
class AceT5Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.voicebpe = LyricsTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.umt5base = UMT5BaseTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out["lyrics"] = self.voicebpe.tokenize_with_weights(kwargs.get("lyrics", ""), return_word_ids, **kwargs)
out["umt5base"] = self.umt5base.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
def untokenize(self, token_weight_pair):
return self.umt5base.untokenize(token_weight_pair)
def state_dict(self):
return self.umt5base.state_dict()
class AceT5Model(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__()
self.umt5base = UMT5BaseModel(device=device, dtype=dtype, model_options=model_options)
self.dtypes = set()
if dtype is not None:
self.dtypes.add(dtype)
def set_clip_options(self, options):
self.umt5base.set_clip_options(options)
def reset_clip_options(self):
self.umt5base.reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_umt5base = token_weight_pairs["umt5base"]
token_weight_pairs_lyrics = token_weight_pairs["lyrics"]
t5_out, t5_pooled = self.umt5base.encode_token_weights(token_weight_pairs_umt5base)
lyrics_embeds = torch.tensor(list(map(lambda a: a[0], token_weight_pairs_lyrics[0]))).unsqueeze(0)
return t5_out, None, {"conditioning_lyrics": lyrics_embeds}
def load_sd(self, sd):
return self.umt5base.load_sd(sd)

File diff suppressed because it is too large

View File

@ -0,0 +1,395 @@
# basic text cleaners for the ACE step model
# I didn't copy the ones from the reference code because I didn't want to deal with the dependencies
# TODO: more languages than english?
import re
def japanese_to_romaji(japanese_text):
"""
Convert Japanese hiragana and katakana to romaji (Latin alphabet representation).
Args:
japanese_text (str): Text containing hiragana and/or katakana characters
Returns:
str: The romaji (Latin alphabet) equivalent
"""
# Dictionary mapping kana characters to their romaji equivalents
kana_map = {
# Katakana characters
'ア': 'a', 'イ': 'i', 'ウ': 'u', 'エ': 'e', 'オ': 'o',
'カ': 'ka', 'キ': 'ki', 'ク': 'ku', 'ケ': 'ke', 'コ': 'ko',
'サ': 'sa', 'シ': 'shi', 'ス': 'su', 'セ': 'se', 'ソ': 'so',
'タ': 'ta', 'チ': 'chi', 'ツ': 'tsu', 'テ': 'te', 'ト': 'to',
'ナ': 'na', 'ニ': 'ni', 'ヌ': 'nu', 'ネ': 'ne', 'ノ': 'no',
'ハ': 'ha', 'ヒ': 'hi', 'フ': 'fu', 'ヘ': 'he', 'ホ': 'ho',
'マ': 'ma', 'ミ': 'mi', 'ム': 'mu', 'メ': 'me', 'モ': 'mo',
'ヤ': 'ya', 'ユ': 'yu', 'ヨ': 'yo',
'ラ': 'ra', 'リ': 'ri', 'ル': 'ru', 'レ': 're', 'ロ': 'ro',
'ワ': 'wa', 'ヲ': 'wo', 'ン': 'n',
# Katakana voiced consonants
'ガ': 'ga', 'ギ': 'gi', 'グ': 'gu', 'ゲ': 'ge', 'ゴ': 'go',
'ザ': 'za', 'ジ': 'ji', 'ズ': 'zu', 'ゼ': 'ze', 'ゾ': 'zo',
'ダ': 'da', 'ヂ': 'ji', 'ヅ': 'zu', 'デ': 'de', 'ド': 'do',
'バ': 'ba', 'ビ': 'bi', 'ブ': 'bu', 'ベ': 'be', 'ボ': 'bo',
'パ': 'pa', 'ピ': 'pi', 'プ': 'pu', 'ペ': 'pe', 'ポ': 'po',
# Katakana combinations
'キャ': 'kya', 'キュ': 'kyu', 'キョ': 'kyo',
'シャ': 'sha', 'シュ': 'shu', 'ショ': 'sho',
'チャ': 'cha', 'チュ': 'chu', 'チョ': 'cho',
'ニャ': 'nya', 'ニュ': 'nyu', 'ニョ': 'nyo',
'ヒャ': 'hya', 'ヒュ': 'hyu', 'ヒョ': 'hyo',
'ミャ': 'mya', 'ミュ': 'myu', 'ミョ': 'myo',
'リャ': 'rya', 'リュ': 'ryu', 'リョ': 'ryo',
'ギャ': 'gya', 'ギュ': 'gyu', 'ギョ': 'gyo',
'ジャ': 'ja', 'ジュ': 'ju', 'ジョ': 'jo',
'ビャ': 'bya', 'ビュ': 'byu', 'ビョ': 'byo',
'ピャ': 'pya', 'ピュ': 'pyu', 'ピョ': 'pyo',
# Katakana small characters and special cases
'ッ': '', # Small tsu (doubles the following consonant)
'ャ': 'ya', 'ュ': 'yu', 'ョ': 'yo',
# Katakana extras
'ヴ': 'vu', 'ファ': 'fa', 'フィ': 'fi', 'フェ': 'fe', 'フォ': 'fo',
'ウィ': 'wi', 'ウェ': 'we', 'ウォ': 'wo',
# Hiragana characters
'あ': 'a', 'い': 'i', 'う': 'u', 'え': 'e', 'お': 'o',
'か': 'ka', 'き': 'ki', 'く': 'ku', 'け': 'ke', 'こ': 'ko',
'さ': 'sa', 'し': 'shi', 'す': 'su', 'せ': 'se', 'そ': 'so',
'た': 'ta', 'ち': 'chi', 'つ': 'tsu', 'て': 'te', 'と': 'to',
'な': 'na', 'に': 'ni', 'ぬ': 'nu', 'ね': 'ne', 'の': 'no',
'は': 'ha', 'ひ': 'hi', 'ふ': 'fu', 'へ': 'he', 'ほ': 'ho',
'ま': 'ma', 'み': 'mi', 'む': 'mu', 'め': 'me', 'も': 'mo',
'や': 'ya', 'ゆ': 'yu', 'よ': 'yo',
'ら': 'ra', 'り': 'ri', 'る': 'ru', 'れ': 're', 'ろ': 'ro',
'わ': 'wa', 'を': 'wo', 'ん': 'n',
# Hiragana voiced consonants
'が': 'ga', 'ぎ': 'gi', 'ぐ': 'gu', 'げ': 'ge', 'ご': 'go',
'ざ': 'za', 'じ': 'ji', 'ず': 'zu', 'ぜ': 'ze', 'ぞ': 'zo',
'だ': 'da', 'ぢ': 'ji', 'づ': 'zu', 'で': 'de', 'ど': 'do',
'ば': 'ba', 'び': 'bi', 'ぶ': 'bu', 'べ': 'be', 'ぼ': 'bo',
'ぱ': 'pa', 'ぴ': 'pi', 'ぷ': 'pu', 'ぺ': 'pe', 'ぽ': 'po',
# Hiragana combinations
'きゃ': 'kya', 'きゅ': 'kyu', 'きょ': 'kyo',
'しゃ': 'sha', 'しゅ': 'shu', 'しょ': 'sho',
'ちゃ': 'cha', 'ちゅ': 'chu', 'ちょ': 'cho',
'にゃ': 'nya', 'にゅ': 'nyu', 'にょ': 'nyo',
'ひゃ': 'hya', 'ひゅ': 'hyu', 'ひょ': 'hyo',
'みゃ': 'mya', 'みゅ': 'myu', 'みょ': 'myo',
'りゃ': 'rya', 'りゅ': 'ryu', 'りょ': 'ryo',
'ぎゃ': 'gya', 'ぎゅ': 'gyu', 'ぎょ': 'gyo',
'じゃ': 'ja', 'じゅ': 'ju', 'じょ': 'jo',
'びゃ': 'bya', 'びゅ': 'byu', 'びょ': 'byo',
'ぴゃ': 'pya', 'ぴゅ': 'pyu', 'ぴょ': 'pyo',
# Hiragana small characters and special cases
'っ': '', # Small tsu (doubles the following consonant)
'ゃ': 'ya', 'ゅ': 'yu', 'ょ': 'yo',
# Common punctuation and spaces
'　': ' ', # Japanese space
'、': ', ', '。': '. ',
}
result = []
i = 0
while i < len(japanese_text):
# Check for small tsu (doubling the following consonant)
if i < len(japanese_text) - 1 and (japanese_text[i] == 'っ' or japanese_text[i] == 'ッ'):
if i < len(japanese_text) - 1 and japanese_text[i+1] in kana_map:
next_romaji = kana_map[japanese_text[i+1]]
if next_romaji and next_romaji[0] not in 'aiueon':
result.append(next_romaji[0]) # Double the consonant
i += 1
continue
# Check for combinations with small ya, yu, yo
if i < len(japanese_text) - 1 and japanese_text[i+1] in ('ゃ', 'ゅ', 'ょ', 'ャ', 'ュ', 'ョ'):
combo = japanese_text[i:i+2]
if combo in kana_map:
result.append(kana_map[combo])
i += 2
continue
# Regular character
if japanese_text[i] in kana_map:
result.append(kana_map[japanese_text[i]])
else:
# If it's not in our map, keep it as is (might be kanji, romaji, etc.)
result.append(japanese_text[i])
i += 1
return ''.join(result)
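A few hand-traced examples of the conversion above (illustrative inputs):

assert japanese_to_romaji("こんにちは") == "konnichiha"
assert japanese_to_romaji("がっこう") == "gakkou"  # small tsu doubles the following consonant
assert japanese_to_romaji("きょうと") == "kyouto"  # combination kana take precedence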
def number_to_text(num, ordinal=False):
"""
Convert a number (int or float) to its text representation.
Args:
num: The number to convert
Returns:
str: Text representation of the number
"""
if not isinstance(num, (int, float)):
return "Input must be a number"
# Handle special case of zero
if num == 0:
return "zero"
# Handle negative numbers
negative = num < 0
num = abs(num)
# Handle floats
if isinstance(num, float):
# Split into integer and decimal parts
int_part = int(num)
# Convert both parts
int_text = _int_to_text(int_part)
# Handle decimal part (convert to string and remove '0.')
decimal_str = str(num).split('.')[1]
decimal_text = " point " + " ".join(_digit_to_text(int(digit)) for digit in decimal_str)
result = int_text + decimal_text
else:
# Handle integers
result = _int_to_text(num)
# Add 'negative' prefix for negative numbers
if negative:
result = "negative " + result
return result
def _int_to_text(num):
"""Helper function to convert an integer to text"""
ones = ["", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine",
"ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", "sixteen",
"seventeen", "eighteen", "nineteen"]
tens = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"]
if num < 20:
return ones[num]
if num < 100:
return tens[num // 10] + (" " + ones[num % 10] if num % 10 != 0 else "")
if num < 1000:
return ones[num // 100] + " hundred" + (" " + _int_to_text(num % 100) if num % 100 != 0 else "")
if num < 1000000:
return _int_to_text(num // 1000) + " thousand" + (" " + _int_to_text(num % 1000) if num % 1000 != 0 else "")
if num < 1000000000:
return _int_to_text(num // 1000000) + " million" + (" " + _int_to_text(num % 1000000) if num % 1000000 != 0 else "")
return _int_to_text(num // 1000000000) + " billion" + (" " + _int_to_text(num % 1000000000) if num % 1000000000 != 0 else "")
def _digit_to_text(digit):
"""Convert a single digit to text"""
digits = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
return digits[digit]
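A few hand-traced examples of the English number expansion above:

assert number_to_text(0) == "zero"
assert number_to_text(1234) == "one thousand two hundred thirty four"
assert number_to_text(3.14) == "three point one four"
assert number_to_text(-42) == "negative forty two"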
_whitespace_re = re.compile(r"\s+")
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = {
"en": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
("dr", "doctor"),
("st", "saint"),
("co", "company"),
("jr", "junior"),
("maj", "major"),
("gen", "general"),
("drs", "doctors"),
("rev", "reverend"),
("lt", "lieutenant"),
("hon", "honorable"),
("sgt", "sergeant"),
("capt", "captain"),
("esq", "esquire"),
("ltd", "limited"),
("col", "colonel"),
("ft", "fort"),
]
],
}
def expand_abbreviations_multilingual(text, lang="en"):
for regex, replacement in _abbreviations[lang]:
text = re.sub(regex, replacement, text)
return text
_symbols_multilingual = {
"en": [
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
("&", " and "),
("@", " at "),
("%", " percent "),
("#", " hash "),
("$", " dollar "),
("£", " pound "),
("°", " degree "),
]
],
}
def expand_symbols_multilingual(text, lang="en"):
for regex, replacement in _symbols_multilingual[lang]:
text = re.sub(regex, replacement, text)
text = text.replace("  ", " ") # Ensure there are no double spaces
return text.strip()
_ordinal_re = {
"en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
}
_number_re = re.compile(r"[0-9]+")
_currency_re = {
"USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
"GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
"EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
}
_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
_dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
def _remove_commas(m):
text = m.group(0)
if "," in text:
text = text.replace(",", "")
return text
def _remove_dots(m):
text = m.group(0)
if "." in text:
text = text.replace(".", "")
return text
def _expand_decimal_point(m, lang="en"):
amount = m.group(1).replace(",", ".")
return number_to_text(float(amount))
def _expand_currency(m, lang="en", currency="USD"):
amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
full_amount = number_to_text(amount)
and_equivalents = {
"en": ", ",
"es": " con ",
"fr": " et ",
"de": " und ",
"pt": " e ",
"it": " e ",
"pl": ", ",
"cs": ", ",
"ru": ", ",
"nl": ", ",
"ar": ", ",
"tr": ", ",
"hu": ", ",
"ko": ", ",
}
if amount.is_integer():
last_and = full_amount.rfind(and_equivalents[lang])
if last_and != -1:
full_amount = full_amount[:last_and]
return full_amount
def _expand_ordinal(m, lang="en"):
return number_to_text(int(m.group(1)), ordinal=True)
def _expand_number(m, lang="en"):
return number_to_text(int(m.group(0)))
def expand_numbers_multilingual(text, lang="en"):
if lang in ["en", "ru"]:
text = re.sub(_comma_number_re, _remove_commas, text)
else:
text = re.sub(_dot_number_re, _remove_dots, text)
try:
text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
except:
pass
text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
return text
def lowercase(text):
return text.lower()
def collapse_whitespace(text):
return re.sub(_whitespace_re, " ", text)
def multilingual_cleaners(text, lang):
text = text.replace('"', "")
if lang == "tr":
text = text.replace("İ", "i")
text = text.replace("Ö", "ö")
text = text.replace("Ü", "ü")
text = lowercase(text)
try:
text = expand_numbers_multilingual(text, lang)
except:
pass
try:
text = expand_abbreviations_multilingual(text, lang)
except:
pass
try:
text = expand_symbols_multilingual(text, lang=lang)
except:
pass
text = collapse_whitespace(text)
return text
def basic_cleaners(text):
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
text = lowercase(text)
text = collapse_whitespace(text)
return text
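A hand-traced example of the full English cleaning pipeline above (numbers are expanded first, then abbreviations, then symbols, and whitespace is collapsed last):

text = multilingual_cleaners("Dr. Smith & Mr. Jones bought 3 CDs.", "en")
assert text == "doctor smith and mister jones bought three cds."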

View File

@ -19,8 +19,8 @@ class FluxTokenizer:
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
def untokenize(self, token_weight_pair):

View File

@ -16,11 +16,11 @@ class HiDreamTokenizer:
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids)
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens
out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
def untokenize(self, token_weight_pair):

View File

@ -49,13 +49,13 @@ class HunyuanVideoTokenizer:
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
out = {}
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids)
llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids, **kwargs)
embed_count = 0
for r in llama_text_tokens:
for i in range(len(r)):

View File

@@ -41,8 +41,8 @@ class HyditTokenizer:
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids)
out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids)
out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids, **kwargs)
out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
def untokenize(self, token_weight_pair):

View File

@@ -45,9 +45,9 @@ class SD3Tokenizer:
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
def untokenize(self, token_weight_pair):

View File

@@ -0,0 +1,22 @@
{
"d_ff": 2048,
"d_kv": 64,
"d_model": 768,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"dense_act_fn": "gelu_pytorch_tanh",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "umt5",
"num_decoder_layers": 12,
"num_heads": 12,
"num_layers": 12,
"output_past": true,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"vocab_size": 256384
}

View File

@@ -28,6 +28,9 @@ import logging
import itertools
from torch.nn.functional import interpolate
from einops import rearrange
from comfy.cli_args import args
MMAP_TORCH_FILES = args.mmap_torch_files
ALWAYS_SAFE_LOAD = False
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
@@ -67,8 +70,12 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
raise e
else:
torch_args = {}
if MMAP_TORCH_FILES:
torch_args["mmap"] = True
if safe_load or ALWAYS_SAFE_LOAD:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
else:
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
if "global_step" in pl_sd:

View File

@@ -24,7 +24,7 @@ class BOFTAdapter(WeightAdapterBase):
) -> Optional["BOFTAdapter"]:
if loaded_keys is None:
loaded_keys = set()
blocks_name = "{}.boft_blocks".format(x)
blocks_name = "{}.oft_blocks".format(x)
rescale_name = "{}.rescale".format(x)
blocks = None
@@ -32,17 +32,18 @@ class BOFTAdapter(WeightAdapterBase):
blocks = lora[blocks_name]
if blocks.ndim == 4:
loaded_keys.add(blocks_name)
else:
blocks = None
if blocks is None:
return None
rescale = None
if rescale_name in lora.keys():
rescale = lora[rescale_name]
loaded_keys.add(rescale_name)
if blocks is not None:
weights = (blocks, rescale, alpha, dora_scale)
return cls(loaded_keys, weights)
else:
return None
weights = (blocks, rescale, alpha, dora_scale)
return cls(loaded_keys, weights)
def calculate_weight(
self,
@@ -71,7 +72,7 @@ class BOFTAdapter(WeightAdapterBase):
# Get r
I = torch.eye(boft_b, device=blocks.device, dtype=blocks.dtype)
# for Q = -Q^T
q = blocks - blocks.transpose(1, 2)
q = blocks - blocks.transpose(-1, -2)
normed_q = q
if alpha > 0: # alpha in boft/bboft is for constraint
q_norm = torch.norm(q) + 1e-8
@@ -79,9 +80,8 @@ class BOFTAdapter(WeightAdapterBase):
normed_q = q * alpha / q_norm
# use float() to prevent unsupported type in .inverse()
r = (I + normed_q) @ (I - normed_q).float().inverse()
r = r.to(original_weight)
inp = org = original_weight
r = r.to(weight)
inp = org = weight
r_b = boft_b//2
for i in range(boft_m):
@@ -91,14 +91,14 @@ class BOFTAdapter(WeightAdapterBase):
if strength != 1:
bi = bi * strength + (1-strength) * I
inp = (
inp.unflatten(-1, (-1, g, k))
.transpose(-2, -1)
.flatten(-3)
.unflatten(-1, (-1, boft_b))
inp.unflatten(0, (-1, g, k))
.transpose(1, 2)
.flatten(0, 2)
.unflatten(0, (-1, boft_b))
)
inp = torch.einsum("b n m, b n ... -> b m ...", inp, bi)
inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp)
inp = (
inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3)
inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2)
)
if rescale is not None:
@@ -109,7 +109,7 @@ class BOFTAdapter(WeightAdapterBase):
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
weight += function((strength * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight
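# Illustrative check (not part of the change above): the Cayley transform used in
# calculate_weight maps a skew-symmetric block Q to an orthogonal block R = (I + Q)(I - Q)^-1,
# so the BOFT update rotates weight blocks rather than rescaling them. Assumes `torch` is
# imported as elsewhere in this module.
def _cayley_orthogonality_demo(n: int = 4) -> bool:
    b = torch.randn(n, n)
    q = b - b.transpose(-1, -2)                      # skew-symmetric, matching the code above
    eye = torch.eye(n)
    r = (eye + q) @ (eye - q).float().inverse()
    return bool(torch.allclose(r @ r.transpose(-1, -2), eye, atol=1e-4))  # True: R is orthogonal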

View File

@@ -32,17 +32,18 @@ class OFTAdapter(WeightAdapterBase):
blocks = lora[blocks_name]
if blocks.ndim == 3:
loaded_keys.add(blocks_name)
else:
blocks = None
if blocks is None:
return None
rescale = None
if rescale_name in lora.keys():
rescale = lora[rescale_name]
loaded_keys.add(rescale_name)
if blocks is not None:
weights = (blocks, rescale, alpha, dora_scale)
return cls(loaded_keys, weights)
else:
return None
weights = (blocks, rescale, alpha, dora_scale)
return cls(loaded_keys, weights)
def calculate_weight(
self,
@@ -79,16 +80,17 @@ class OFTAdapter(WeightAdapterBase):
normed_q = q * alpha / q_norm
# use float() to prevent unsupported type in .inverse()
r = (I + normed_q) @ (I - normed_q).float().inverse()
r = r.to(original_weight)
r = r.to(weight)
_, *shape = weight.shape
lora_diff = torch.einsum(
"k n m, k n ... -> k m ...",
(r * strength) - strength * I,
original_weight,
)
weight.view(block_num, block_size, *shape),
).view(-1, *shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
weight += function((strength * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight

View File

@@ -0,0 +1,8 @@
from .basic_types import ImageInput, AudioInput
from .video_types import VideoInput
__all__ = [
"ImageInput",
"AudioInput",
"VideoInput",
]

View File

@@ -0,0 +1,20 @@
import torch
from typing import TypedDict
ImageInput = torch.Tensor
"""
An image in format [B, H, W, C] where B is the batch size, H and W are the height and width, and C is the number of channels.
"""
class AudioInput(TypedDict):
"""
TypedDict representing audio input.
"""
waveform: torch.Tensor
"""
Tensor in the format [B, C, T] where B is the batch size, C is the number of channels, and T is the number of samples.
"""
sample_rate: int
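# Illustrative usage (not part of the change above): constructing the input types defined here.
# Shapes and the sample rate are placeholder values.
def _example_inputs() -> tuple[ImageInput, AudioInput]:
    image: ImageInput = torch.zeros(1, 512, 512, 3)   # [B, H, W, C]
    audio: AudioInput = {
        "waveform": torch.zeros(1, 2, 48000),          # [B, C, T]
        "sample_rate": 48000,
    }
    return image, audio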

View File

@@ -0,0 +1,45 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
class VideoInput(ABC):
"""
Abstract base class for video input types.
"""
@abstractmethod
def get_components(self) -> VideoComponents:
"""
Abstract method to get the video components (images, audio, and frame rate).
Returns:
VideoComponents containing images, audio, and frame rate
"""
pass
@abstractmethod
def save_to(
self,
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
"""
Abstract method to save the video input to a file.
"""
pass
# Provide a default implementation, but subclasses can provide optimized versions
# if possible.
def get_dimensions(self) -> tuple[int, int]:
"""
Returns the dimensions of the video input.
Returns:
Tuple of (width, height)
"""
components = self.get_components()
return components.images.shape[2], components.images.shape[1]
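# Minimal sketch (illustrative, not part of the change above) of a VideoInput subclass backed by
# already-decoded components; the real implementations live in comfy_api.input_impl.
class _StaticVideo(VideoInput):
    def __init__(self, components: VideoComponents):
        self._components = components

    def get_components(self) -> VideoComponents:
        return self._components

    def save_to(self, path, format=VideoContainer.AUTO, codec=VideoCodec.AUTO, metadata=None):
        raise NotImplementedError("sketch only; see comfy_api.input_impl.VideoFromComponents")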

View File

@@ -0,0 +1,7 @@
from .video_types import VideoFromFile, VideoFromComponents
__all__ = [
# Implementations
"VideoFromFile",
"VideoFromComponents",
]

View File

@@ -0,0 +1,271 @@
from __future__ import annotations
from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
from comfy_api.input import AudioInput
import av
import io
import json
import numpy as np
import torch
from comfy_api.input import VideoInput
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
def container_to_output_format(container_format: str | None) -> str | None:
"""
A container's `format` may be a comma-separated list of formats.
E.g., an ISO container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
However, writing to a file/stream with `av.open` requires a single format,
or `None` to auto-detect.
"""
if not container_format:
return None # Auto-detect
if "," not in container_format:
return container_format
formats = container_format.split(",")
return formats[0]
def get_open_write_kwargs(
dest: str | io.BytesIO, container_format: str, to_format: str | None
) -> dict:
"""Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
open_kwargs = {
"mode": "w",
# If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
"options": {"movflags": "use_metadata_tags"},
}
is_write_to_buffer = isinstance(dest, io.BytesIO)
if is_write_to_buffer:
# Set output format explicitly, since it cannot be inferred from file extension
if to_format == VideoContainer.AUTO:
to_format = container_format.lower()
elif isinstance(to_format, str):
to_format = to_format.lower()
open_kwargs["format"] = container_to_output_format(to_format)
return open_kwargs
class VideoFromFile(VideoInput):
"""
Class representing video input from a file.
"""
def __init__(self, file: str | io.BytesIO):
"""
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
containing the file contents.
"""
self.__file = file
def get_dimensions(self) -> tuple[int, int]:
"""
Returns the dimensions of the video input.
Returns:
Tuple of (width, height)
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
for stream in container.streams:
if stream.type == 'video':
assert isinstance(stream, av.VideoStream)
return stream.width, stream.height
raise ValueError(f"No video stream found in file '{self.__file}'")
def get_components_internal(self, container: InputContainer) -> VideoComponents:
# Get video frames
frames = []
for frame in container.decode(video=0):
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
frames.append(img)
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
# Get frame rate
video_stream = next(s for s in container.streams if s.type == 'video')
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
# Get audio if available
audio = None
try:
container.seek(0) # Reset the container to the beginning
for stream in container.streams:
if stream.type != 'audio':
continue
assert isinstance(stream, av.AudioStream)
audio_frames = []
for packet in container.demux(stream):
for frame in packet.decode():
assert isinstance(frame, av.AudioFrame)
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
audio = AudioInput({
"waveform": audio_tensor,
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
})
except StopIteration:
pass # No audio stream
metadata = container.metadata
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
def get_components(self) -> VideoComponents:
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
return self.get_components_internal(container)
raise ValueError(f"No video stream found in file '{self.__file}'")
def save_to(
self,
path: str | io.BytesIO,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
container_format = container.format.name
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
reuse_streams = True
if format != VideoContainer.AUTO and format not in container_format.split(","):
reuse_streams = False
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
reuse_streams = False
if not reuse_streams:
components = self.get_components_internal(container)
video = VideoFromComponents(components)
return video.save_to(
path,
format=format,
codec=codec,
metadata=metadata
)
streams = container.streams
open_kwargs = get_open_write_kwargs(path, container_format, format)
with av.open(path, **open_kwargs) as output_container:
# Copy over the original metadata
for key, value in container.metadata.items():
if metadata is None or key not in metadata:
output_container.metadata[key] = value
# Add our new metadata
if metadata is not None:
for key, value in metadata.items():
if isinstance(value, str):
output_container.metadata[key] = value
else:
output_container.metadata[key] = json.dumps(value)
# Add streams to the new container
stream_map = {}
for stream in streams:
if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
stream_map[stream] = out_stream
# Write packets to the new container
for packet in container.demux():
if packet.stream in stream_map and packet.dts is not None:
packet.stream = stream_map[packet.stream]
output_container.mux(packet)
class VideoFromComponents(VideoInput):
"""
Class representing video input from tensors.
"""
def __init__(self, components: VideoComponents):
self.__components = components
def get_components(self) -> VideoComponents:
return VideoComponents(
images=self.__components.images,
audio=self.__components.audio,
frame_rate=self.__components.frame_rate
)
def save_to(
self,
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
raise ValueError("Only H264 codec is supported for now")
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
# Add metadata before writing any streams
if metadata is not None:
for key, value in metadata.items():
output.metadata[key] = json.dumps(value)
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
# Create a video stream
video_stream = output.add_stream('h264', rate=frame_rate)
video_stream.width = self.__components.images.shape[2]
video_stream.height = self.__components.images.shape[1]
video_stream.pix_fmt = 'yuv420p'
# Create an audio stream
audio_sample_rate = 1
audio_stream: Optional[av.AudioStream] = None
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
audio_stream.sample_rate = audio_sample_rate
audio_stream.format = 'fltp'
# Encode video
for i, frame in enumerate(self.__components.images):
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
packet = video_stream.encode(frame)
output.mux(packet)
# Flush video
packet = video_stream.encode(None)
output.mux(packet)
if audio_stream and self.__components.audio:
# Encode audio
samples_per_frame = int(audio_sample_rate / frame_rate)
num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
for i in range(num_frames):
start = i * samples_per_frame
end = start + samples_per_frame
# TODO(Feature) - Add support for stereo audio
chunk = (
self.__components.audio["waveform"][0, 0, start:end]
.unsqueeze(0)
.contiguous()
.numpy()
)
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
audio_frame.sample_rate = audio_sample_rate
audio_frame.pts = i * samples_per_frame
for packet in audio_stream.encode(audio_frame):
output.mux(packet)
# Flush audio
for packet in audio_stream.encode(None):
output.mux(packet)
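# Illustrative usage (not part of the change above); file paths are placeholders.
if __name__ == "__main__":
    video = VideoFromFile("input.mp4")       # or an io.BytesIO holding the file contents
    width, height = video.get_dimensions()   # reads stream metadata without decoding frames
    components = video.get_components()      # decodes frames (and audio, if present)
    VideoFromComponents(components).save_to(
        "output.mp4", format=VideoContainer.MP4, codec=VideoCodec.H264
    )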

View File

@@ -0,0 +1,8 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents
__all__ = [
# Utility Types
"VideoContainer",
"VideoCodec",
"VideoComponents",
]

View File

@@ -0,0 +1,51 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from fractions import Fraction
from typing import Optional
from comfy_api.input import ImageInput, AudioInput
class VideoCodec(str, Enum):
AUTO = "auto"
H264 = "h264"
@classmethod
def as_input(cls) -> list[str]:
"""
Returns a list of codec names that can be used as node input.
"""
return [member.value for member in cls]
class VideoContainer(str, Enum):
AUTO = "auto"
MP4 = "mp4"
@classmethod
def as_input(cls) -> list[str]:
"""
Returns a list of container names that can be used as node input.
"""
return [member.value for member in cls]
@classmethod
def get_extension(cls, value) -> str:
"""
Returns the file extension for the container.
"""
if isinstance(value, str):
value = cls(value)
if value == VideoContainer.MP4 or value == VideoContainer.AUTO:
return "mp4"
return ""
@dataclass
class VideoComponents:
"""
Dataclass representing the components of a video.
"""
images: ImageInput
frame_rate: Fraction
audio: Optional[AudioInput] = None
metadata: Optional[dict] = None
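# Illustrative usage (not part of the change above): how nodes typically consume these types.
def _example_enum_usage() -> str:
    # Both enums expose their values for node input dropdowns.
    assert VideoCodec.as_input() == ["auto", "h264"]
    assert VideoContainer.as_input() == ["auto", "mp4"]
    return VideoContainer.get_extension(VideoContainer.AUTO)  # "mp4"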

41 comfy_api_nodes/README.md Normal file
View File

@@ -0,0 +1,41 @@
# ComfyUI API Nodes
## Introduction
Below is a collection of nodes that work by calling external APIs. More information is available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview#api-nodes).
## Development
While developing, you should be testing against the Staging environment. To test against staging:
**Install ComfyUI_frontend**
Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to start the frontend server. By default, it will connect to Staging authentication.
> **Hint:** If you use the `--front-end-version` argument for ComfyUI, it will use production authentication.
```bash
python main.py --comfy-api-base https://stagingapi.comfy.org
```
API stubs are generated through automatic codegen tools from OpenAPI definitions. Since the Comfy Org OpenAPI definition contains many things from the Comfy Registry as well, we use redocly/cli to filter out only the paths relevant for API nodes.
### Redocly Instructions
**Tip**
When developing locally, use the `redocly-dev.yaml` file to generate pydantic models. This lets you use stubs for APIs that are not marked `Released` yet.
Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging.
```bash
# Download the OpenAPI file from the staging server.
curl -o openapi.yaml https://stagingapi.comfy.org/openapi
# Filter out unneeded API definitions.
npm install -g @redocly/cli
redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly-dev.yaml --remove-unused-components
# Generate the pydantic datamodels for validation.
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
```

View File

View File

@@ -0,0 +1,576 @@
from __future__ import annotations
import io
import logging
from typing import Optional
from comfy.utils import common_upscale
from comfy_api.input_impl import VideoFromFile
from comfy_api.util import VideoContainer, VideoCodec
from comfy_api.input.video_types import VideoInput
from comfy_api.input.basic_types import AudioInput
from comfy_api_nodes.apis.client import (
ApiClient,
ApiEndpoint,
HttpMethod,
SynchronousOperation,
UploadRequest,
UploadResponse,
)
import numpy as np
from PIL import Image
import requests
import torch
import math
import base64
import uuid
from io import BytesIO
import av
def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output.
Args:
video_url: The URL of the video to download.
Returns:
A Comfy node `VIDEO` output.
"""
video_io = download_url_to_bytesio(video_url, timeout)
if video_io is None:
error_msg = f"Failed to download video from {video_url}"
logging.error(error_msg)
raise ValueError(error_msg)
return VideoFromFile(video_io)
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
"""Downscale input image tensor to roughly the specified total pixels."""
samples = image.movedim(-1, 1)
total = int(total_pixels)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
if scale_by >= 1:
return image
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = common_upscale(samples, width, height, "lanczos", "disabled")
s = s.movedim(1, -1)
return s
def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor:
"""Validates and casts a response to a torch.Tensor.
Args:
response: The response to validate and cast.
timeout: Request timeout in seconds. Defaults to None (no timeout).
Returns:
A torch.Tensor representing the image (1, H, W, C).
Raises:
ValueError: If the response is not valid.
"""
# validate raw JSON response
data = response.data
if not data or len(data) == 0:
raise ValueError("No images returned from API endpoint")
# Initialize list to store image tensors
image_tensors: list[torch.Tensor] = []
# Process each image in the data array
for image_data in data:
image_url = image_data.url
b64_data = image_data.b64_json
if not image_url and not b64_data:
raise ValueError("No image was generated in the response")
if b64_data:
img_data = base64.b64decode(b64_data)
img = Image.open(io.BytesIO(img_data))
elif image_url:
img_response = requests.get(image_url, timeout=timeout)
if img_response.status_code != 200:
raise ValueError("Failed to download the image")
img = Image.open(io.BytesIO(img_response.content))
img = img.convert("RGBA")
# Convert to numpy array, normalize to float32 between 0 and 1
img_array = np.array(img).astype(np.float32) / 255.0
img_tensor = torch.from_numpy(img_array)
# Add to list of tensors
image_tensors.append(img_tensor)
return torch.stack(image_tensors, dim=0)
def validate_aspect_ratio(
aspect_ratio: str,
minimum_ratio: float,
maximum_ratio: float,
minimum_ratio_str: str,
maximum_ratio_str: str,
) -> float:
"""Validates and casts an aspect ratio string to a float.
Args:
aspect_ratio: The aspect ratio string to validate.
minimum_ratio: The minimum aspect ratio.
maximum_ratio: The maximum aspect ratio.
minimum_ratio_str: The minimum aspect ratio string.
maximum_ratio_str: The maximum aspect ratio string.
Returns:
The validated and cast aspect ratio.
Raises:
Exception: If the aspect ratio is not valid.
"""
# get ratio values
numbers = aspect_ratio.split(":")
if len(numbers) != 2:
raise TypeError(
f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
)
try:
numerator = int(numbers[0])
denominator = int(numbers[1])
except ValueError as exc:
raise TypeError(
f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
) from exc
calculated_ratio = numerator / denominator
# if not close to minimum and maximum, check bounds
if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose(
calculated_ratio, maximum_ratio
):
if calculated_ratio < minimum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
elif calculated_ratio > maximum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
return aspect_ratio
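# Illustrative usage (not part of the change above); the bounds below are hypothetical example values.
def _example_validate_aspect_ratio() -> str:
    # "16:9" reduces to about 1.78, inside [0.25, 4.0], so the original string is returned.
    # "1:10" would reduce to 0.1 and raise TypeError for falling below the 1:4 minimum.
    return validate_aspect_ratio("16:9", minimum_ratio=0.25, maximum_ratio=4.0,
                                 minimum_ratio_str="1:4", maximum_ratio_str="4:1")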
def mimetype_to_extension(mime_type: str) -> str:
"""Converts a MIME type to a file extension."""
return mime_type.split("/")[-1].lower()
def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
"""Downloads content from a URL using requests and returns it as BytesIO.
Args:
url: The URL to download.
timeout: Request timeout in seconds. Defaults to None (no timeout).
Returns:
BytesIO object containing the downloaded content.
"""
response = requests.get(url, stream=True, timeout=timeout)
response.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
return BytesIO(response.content)
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
"""Converts image data from BytesIO to a torch.Tensor.
Args:
image_bytesio: BytesIO object containing the image data.
mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").
Returns:
A torch.Tensor representing the image (1, H, W, C).
Raises:
PIL.UnidentifiedImageError: If the image data cannot be identified.
ValueError: If the specified mode is invalid.
"""
image = Image.open(image_bytesio)
image = image.convert(mode)
image_array = np.array(image).astype(np.float32) / 255.0
return torch.from_numpy(image_array).unsqueeze(0)
def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
image_bytesio = download_url_to_bytesio(url, timeout)
return bytesio_to_image_tensor(image_bytesio)
def process_image_response(response: requests.Response) -> torch.Tensor:
"""Uses content from a Response object and converts it to a torch.Tensor"""
return bytesio_to_image_tensor(BytesIO(response.content))
def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
"""Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
if len(image.shape) > 3:
image = image[0]
# TODO: remove alpha if not allowed and present
input_tensor = image.cpu()
input_tensor = downscale_image_tensor(
input_tensor.unsqueeze(0), total_pixels=total_pixels
).squeeze()
image_np = (input_tensor.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
return img
def _pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
"""Converts a PIL Image to a BytesIO object."""
if not mime_type:
mime_type = "image/png"
img_byte_arr = io.BytesIO()
# Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
pil_format = mime_type.split("/")[-1].upper()
if pil_format == "JPG":
pil_format = "JPEG"
img.save(img_byte_arr, format=pil_format)
img_byte_arr.seek(0)
return img_byte_arr
def tensor_to_bytesio(
image: torch.Tensor,
name: Optional[str] = None,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> BytesIO:
"""Converts a torch.Tensor image to a named BytesIO object.
Args:
image: Input torch.Tensor image.
name: Optional filename for the BytesIO object.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
Returns:
Named BytesIO object containing the image data.
"""
if not mime_type:
mime_type = "image/png"
pil_image = _tensor_to_pil(image, total_pixels=total_pixels)
img_binary = _pil_to_bytesio(pil_image, mime_type=mime_type)
img_binary.name = (
f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
)
return img_binary
def tensor_to_base64_string(
image_tensor: torch.Tensor,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> str:
"""Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.
Args:
image_tensor: Input torch.Tensor image.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
Returns:
Base64 encoded string of the image.
"""
pil_image = _tensor_to_pil(image_tensor, total_pixels=total_pixels)
img_byte_arr = _pil_to_bytesio(pil_image, mime_type=mime_type)
img_bytes = img_byte_arr.getvalue()
# Encode bytes to base64 string
base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
return base64_encoded_string
def tensor_to_data_uri(
image_tensor: torch.Tensor,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> str:
"""Converts a tensor image to a Data URI string.
Args:
image_tensor: Input torch.Tensor image.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
Returns:
Data URI string (e.g., 'data:image/png;base64,...').
"""
base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
return f"data:{mime_type};base64,{base64_string}"
def upload_file_to_comfyapi(
file_bytes_io: BytesIO,
filename: str,
upload_mime_type: str,
auth_kwargs: Optional[dict[str,str]] = None,
) -> str:
"""
Uploads a single file to ComfyUI API and returns its download URL.
Args:
file_bytes_io: BytesIO object containing the file data.
filename: The filename of the file.
upload_mime_type: MIME type of the file.
auth_kwargs: Optional authentication token(s).
Returns:
The download URL for the uploaded file.
"""
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/customers/storage",
method=HttpMethod.POST,
request_model=UploadRequest,
response_model=UploadResponse,
),
request=request_object,
auth_kwargs=auth_kwargs,
)
response: UploadResponse = operation.execute()
upload_response = ApiClient.upload_file(
response.upload_url, file_bytes_io, content_type=upload_mime_type
)
upload_response.raise_for_status()
return response.download_url
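# Illustrative usage (not part of the change above): the helper performs the two-step upload
# (request a signed URL from /customers/storage, then PUT the bytes). The auth value is a placeholder.
def _example_upload(file_bytes: bytes) -> str:
    return upload_file_to_comfyapi(
        BytesIO(file_bytes),
        filename="example.png",
        upload_mime_type="image/png",
        auth_kwargs={"auth_token": "placeholder-token"},  # hypothetical; pass the node's hidden auth kwargs
    )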
def upload_video_to_comfyapi(
video: VideoInput,
auth_kwargs: Optional[dict[str,str]] = None,
container: VideoContainer = VideoContainer.MP4,
codec: VideoCodec = VideoCodec.H264,
max_duration: Optional[int] = None,
) -> str:
"""
Uploads a single video to ComfyUI API and returns its download URL.
Uses the specified container and codec for saving the video before upload.
Args:
video: VideoInput object (Comfy VIDEO type).
auth_kwargs: Optional authentication token(s).
container: The video container format to use (default: MP4).
codec: The video codec to use (default: H264).
max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised.
Returns:
The download URL for the uploaded video file.
"""
if max_duration is not None:
try:
actual_duration = video.duration_seconds
if actual_duration is not None and actual_duration > max_duration:
raise ValueError(
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
)
except Exception as e:
logging.error(f"Error getting video duration: {e}")
raise ValueError(f"Could not verify video duration from source: {e}") from e
upload_mime_type = f"video/{container.value.lower()}"
filename = f"uploaded_video.{container.value.lower()}"
# Convert VideoInput to BytesIO using specified container/codec
video_bytes_io = io.BytesIO()
video.save_to(video_bytes_io, format=container, codec=codec)
video_bytes_io.seek(0)
return upload_file_to_comfyapi(
video_bytes_io, filename, upload_mime_type, auth_kwargs
)
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
"""
Prepares audio waveform for av library by converting to a contiguous numpy array.
Args:
waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type.
Returns:
Contiguous numpy array of the audio waveform. If the audio was batched,
the first item is taken.
"""
if waveform.ndim != 3 or waveform.shape[0] != 1:
raise ValueError("Expected waveform tensor shape (1, channels, samples)")
# If batch is > 1, take first item
if waveform.shape[0] > 1:
waveform = waveform[0]
# Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array
audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
if audio_data_np.dtype != np.float32:
audio_data_np = audio_data_np.astype(np.float32)
return audio_data_np
def audio_ndarray_to_bytesio(
audio_data_np: np.ndarray,
sample_rate: int,
container_format: str = "mp4",
codec_name: str = "aac",
) -> BytesIO:
"""
Encodes a numpy array of audio data into a BytesIO object.
"""
audio_bytes_io = io.BytesIO()
with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
frame = av.AudioFrame.from_ndarray(
audio_data_np,
format="fltp",
layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
)
frame.sample_rate = sample_rate
frame.pts = 0
for packet in audio_stream.encode(frame):
output_container.mux(packet)
# Flush stream
for packet in audio_stream.encode(None):
output_container.mux(packet)
audio_bytes_io.seek(0)
return audio_bytes_io
def upload_audio_to_comfyapi(
audio: AudioInput,
auth_kwargs: Optional[dict[str,str]] = None,
container_format: str = "mp4",
codec_name: str = "aac",
mime_type: str = "audio/mp4",
filename: str = "uploaded_audio.mp4",
) -> str:
"""
Uploads a single audio input to ComfyUI API and returns its download URL.
Encodes the raw waveform into the specified format before uploading.
Args:
audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate)
auth_kwargs: Optional authentication token(s).
Returns:
The download URL for the uploaded audio file.
"""
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(
audio_data_np, sample_rate, container_format, codec_name
)
return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
def upload_images_to_comfyapi(
image: torch.Tensor, max_images=8, auth_kwargs: Optional[dict[str,str]] = None, mime_type: Optional[str] = None
) -> list[str]:
"""
Uploads images to ComfyUI API and returns download URLs.
To upload multiple images, stack them in the batch dimension first.
Args:
image: Input torch.Tensor image.
max_images: Maximum number of images to upload.
auth_kwargs: Optional authentication token(s).
mime_type: Optional MIME type for the image.
"""
# if batch, try to upload each file if max_images is greater than 0
idx_image = 0
download_urls: list[str] = []
is_batch = len(image.shape) > 3
batch_length = 1
if is_batch:
batch_length = image.shape[0]
while True:
curr_image = image
if len(image.shape) > 3:
curr_image = image[idx_image]
# get BytesIO version of image
img_binary = tensor_to_bytesio(curr_image, mime_type=mime_type)
# first, request upload/download urls from comfy API
if not mime_type:
request_object = UploadRequest(file_name=img_binary.name)
else:
request_object = UploadRequest(
file_name=img_binary.name, content_type=mime_type
)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/customers/storage",
method=HttpMethod.POST,
request_model=UploadRequest,
response_model=UploadResponse,
),
request=request_object,
auth_kwargs=auth_kwargs,
)
response = operation.execute()
upload_response = ApiClient.upload_file(
response.upload_url, img_binary, content_type=mime_type
)
# verify success
try:
upload_response.raise_for_status()
except requests.exceptions.HTTPError as e:
raise ValueError(f"Could not upload one or more images: {e}") from e
# add download_url to list
download_urls.append(response.download_url)
idx_image += 1
# stop uploading additional files if done
if is_batch and max_images > 0:
if idx_image >= max_images:
break
if idx_image >= batch_length:
break
return download_urls
def resize_mask_to_image(mask: torch.Tensor, image: torch.Tensor,
upscale_method="nearest-exact", crop="disabled",
allow_gradient=True, add_channel_dim=False):
"""
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
"""
_, H, W, _ = image.shape
mask = mask.unsqueeze(-1)
mask = mask.movedim(-1,1)
mask = common_upscale(mask, width=W, height=H, upscale_method=upscale_method, crop=crop)
mask = mask.movedim(1,-1)
if not add_channel_dim:
mask = mask.squeeze(-1)
if not allow_gradient:
mask = (mask > 0.5).float()
return mask
def validate_string(string: str, strip_whitespace=True, field_name="prompt", min_length=None, max_length=None):
if strip_whitespace:
string = string.strip()
if min_length and len(string) < min_length:
raise Exception(f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long.")
if max_length and len(string) > max_length:
raise Exception(f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long.")
if not string:
raise Exception(f"Field '{field_name}' cannot be empty.")

View File

@@ -0,0 +1,17 @@
# generated by datamodel-codegen:
# filename: filtered-openapi.yaml
# timestamp: 2025-04-29T23:44:54+00:00
from __future__ import annotations
from typing import Optional
from pydantic import BaseModel
from . import PixverseDto
class ResponseData(BaseModel):
ErrCode: Optional[int] = None
ErrMsg: Optional[str] = None
Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None

View File

@@ -0,0 +1,57 @@
# generated by datamodel-codegen:
# filename: filtered-openapi.yaml
# timestamp: 2025-04-29T23:44:54+00:00
from __future__ import annotations
from typing import Optional
from pydantic import BaseModel, Field
class V2OpenAPII2VResp(BaseModel):
video_id: Optional[int] = Field(None, description='Video_id')
class V2OpenAPIT2VReq(BaseModel):
aspect_ratio: str = Field(
..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9']
)
duration: int = Field(
...,
description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)',
examples=[5],
)
model: str = Field(
..., description='Model version (only supports v3.5)', examples=['v3.5']
)
motion_mode: Optional[str] = Field(
'normal',
description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)',
examples=['normal'],
)
negative_prompt: Optional[str] = Field(
None, description='Negative prompt\n', max_length=2048
)
prompt: str = Field(..., description='Prompt', max_length=2048)
quality: str = Field(
...,
description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")',
examples=['540p'],
)
seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647')
style: Optional[str] = Field(
None,
description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed',
examples=['anime'],
)
template_id: Optional[int] = Field(
None,
description='Template ID (template_id must be activated before use)',
examples=[302325299692608],
)
water_mark: Optional[bool] = Field(
False,
description='Watermark (true: add watermark, false: no watermark)',
examples=[False],
)

File diff suppressed because it is too large

View File

@@ -0,0 +1,156 @@
from __future__ import annotations
from enum import Enum
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, confloat, conint
class BFLOutputFormat(str, Enum):
png = 'png'
jpeg = 'jpeg'
class BFLFluxExpandImageRequest(BaseModel):
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
top: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the top of the image')
bottom: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the bottom of the image')
left: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the left side of the image')
right: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the right side of the image')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
image: str = Field(None, description='A Base64-encoded string representing the image you wish to expand')
class BFLFluxFillImageRequest(BaseModel):
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
image: str = Field(None, description='A Base64-encoded string representing the image you wish to modify. Can contain alpha mask if desired.')
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.')
class BFLFluxCannyImageRequest(BaseModel):
prompt: str = Field(..., description='Text prompt for image generation')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
canny_low_threshold: Optional[int] = Field(None, description='Low threshold for Canny edge detection')
canny_high_threshold: Optional[int] = Field(None, description='High threshold for Canny edge detection')
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
class BFLFluxDepthImageRequest(BaseModel):
prompt: str = Field(..., description='Text prompt for image generation')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
class BFLFluxProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
width: conint(ge=256, le=1440) = Field(1024, description='Width of the generated image in pixels. Must be a multiple of 32.')
height: conint(ge=256, le=1440) = Field(768, description='Height of the generated image in pixels. Must be a multiple of 32.')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
# image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
# None, description='Blend between the prompt and the image prompt.'
# )
class BFLFluxProUltraGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
raw: Optional[bool] = Field(None, description='Generate less processed, more natural-looking images.')
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
None, description='Blend between the prompt and the image prompt.'
)
class BFLFluxProGenerateResponse(BaseModel):
id: str = Field(..., description='The unique identifier for the generation task.')
polling_url: str = Field(..., description='URL to poll for the generation result.')
class BFLStatus(str, Enum):
task_not_found = "Task not found"
pending = "Pending"
request_moderated = "Request Moderated"
content_moderated = "Content Moderated"
ready = "Ready"
error = "Error"
class BFLFluxProStatusResponse(BaseModel):
id: str = Field(..., description="The unique identifier for the generation task.")
status: BFLStatus = Field(..., description="The status of the task.")
result: Optional[Dict[str, Any]] = Field(
None, description="The result of the task (null if not completed)."
)
progress: confloat(ge=0.0, le=1.0) = Field(
..., description="The progress of the task (0.0 to 1.0)."
)
details: Optional[Dict[str, Any]] = Field(
None, description="Additional details about the task (null if not available)."
)
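# Illustrative usage (not part of the change above): constructing one of the request models.
# All field values below are placeholder examples.
def _example_bfl_request() -> BFLFluxProGenerateRequest:
    return BFLFluxProGenerateRequest(
        prompt="a mountain lake at sunrise",
        width=1024,   # 256-1440, multiple of 32
        height=768,
        output_format=BFLOutputFormat.png,
    )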

View File

@@ -0,0 +1,635 @@
"""
API Client Framework for api.comfy.org.
This module provides a flexible framework for making API requests from ComfyUI nodes.
It supports both synchronous and asynchronous API operations with proper type validation.
Key Components:
--------------
1. ApiClient - Handles HTTP requests with authentication and error handling
2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models
3. ApiOperation - Executes a single synchronous API operation
Usage Examples:
--------------
# Example 1: Synchronous API Operation
# ------------------------------------
# For a simple API call that returns the result immediately:
# 1. Create the API client
api_client = ApiClient(
base_url="https://api.example.com",
auth_token="your_auth_token_here",
comfy_api_key="your_comfy_api_key_here",
timeout=30.0,
verify_ssl=True
)
# 2. Define the endpoint
user_info_endpoint = ApiEndpoint(
path="/v1/users/me",
method=HttpMethod.GET,
request_model=EmptyRequest, # No request body needed
response_model=UserProfile, # Pydantic model for the response
query_params=None
)
# 3. Create the request object
request = EmptyRequest()
# 4. Create and execute the operation
operation = ApiOperation(
endpoint=user_info_endpoint,
request=request
)
user_profile = operation.execute(client=api_client) # Returns immediately with the result
# Example 2: Asynchronous API Operation with Polling
# -------------------------------------------------
# For an API that starts a task and requires polling for completion:
# 1. Define the endpoints (initial request and polling)
generate_image_endpoint = ApiEndpoint(
path="/v1/images/generate",
method=HttpMethod.POST,
request_model=ImageGenerationRequest,
response_model=TaskCreatedResponse,
query_params=None
)
check_task_endpoint = ApiEndpoint(
path="/v1/tasks/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=ImageGenerationResult,
query_params=None
)
# 2. Create the request object
request = ImageGenerationRequest(
prompt="a beautiful sunset over mountains",
width=1024,
height=1024,
num_images=1
)
# 3. Create and execute the polling operation
operation = PollingOperation(
initial_endpoint=generate_image_endpoint,
initial_request=request,
poll_endpoint=check_task_endpoint,
task_id_field="task_id",
status_field="status",
completed_statuses=["completed"],
failed_statuses=["failed", "error"]
)
# This will make the initial request and then poll until completion
result = operation.execute(client=api_client) # Returns the final ImageGenerationResult when done
"""
from __future__ import annotations
import logging
import time
import io
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable
from enum import Enum
import json
import requests
from urllib.parse import urljoin
from pydantic import BaseModel, Field
from comfy.cli_args import args
from comfy import utils
T = TypeVar("T", bound=BaseModel)
R = TypeVar("R", bound=BaseModel)
P = TypeVar("P", bound=BaseModel) # For poll response
PROGRESS_BAR_MAX = 100
class EmptyRequest(BaseModel):
"""Base class for empty request bodies.
For GET requests, fields will be sent as query parameters."""
pass
class UploadRequest(BaseModel):
file_name: str = Field(..., description="Filename to upload")
content_type: str | None = Field(
None,
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
)
class UploadResponse(BaseModel):
download_url: str = Field(..., description="URL to GET uploaded file")
upload_url: str = Field(..., description="URL to PUT file to upload")
class HttpMethod(str, Enum):
GET = "GET"
POST = "POST"
PUT = "PUT"
DELETE = "DELETE"
PATCH = "PATCH"
class ApiClient:
"""
Client for making HTTP requests to an API with authentication and error handling.
"""
def __init__(
self,
base_url: str,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
timeout: float = 3600.0,
verify_ssl: bool = True,
):
self.base_url = base_url
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
self.timeout = timeout
self.verify_ssl = verify_ssl
def _create_json_payload_args(
self,
data: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
return {
"json": data,
"headers": headers,
}
def _create_form_data_args(
self,
data: Dict[str, Any],
files: Dict[str, Any],
headers: Optional[Dict[str, str]] = None,
multipart_parser = None,
) -> Dict[str, Any]:
if headers and "Content-Type" in headers:
del headers["Content-Type"]
if multipart_parser:
data = multipart_parser(data)
return {
"data": data,
"files": files,
"headers": headers,
}
def _create_urlencoded_form_data_args(
self,
data: Dict[str, Any],
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
headers = headers or {}
headers["Content-Type"] = "application/x-www-form-urlencoded"
return {
"data": data,
"headers": headers,
}
def get_headers(self) -> Dict[str, str]:
"""Get headers for API requests, including authentication if available"""
headers = {"Content-Type": "application/json", "Accept": "application/json"}
if self.auth_token:
headers["Authorization"] = f"Bearer {self.auth_token}"
elif self.comfy_api_key:
headers["X-API-KEY"] = self.comfy_api_key
return headers
def request(
self,
method: str,
path: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
files: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
content_type: str = "application/json",
multipart_parser: Callable = None,
) -> Dict[str, Any]:
"""
Make an HTTP request to the API
Args:
method: HTTP method (GET, POST, etc.)
path: API endpoint path (will be joined with base_url)
params: Query parameters
data: body data
files: Files to upload
headers: Additional headers
content_type: Content type of the request. Defaults to application/json.
Returns:
Parsed JSON response
Raises:
requests.RequestException: If the request fails
"""
url = urljoin(self.base_url, path)
self.check_auth(self.auth_token, self.comfy_api_key)
# Combine default headers with any provided headers
request_headers = self.get_headers()
if headers:
request_headers.update(headers)
# Let requests handle the content type when files are present.
if files:
del request_headers["Content-Type"]
logging.debug(f"[DEBUG] Request Headers: {request_headers}")
logging.debug(f"[DEBUG] Files: {files}")
logging.debug(f"[DEBUG] Params: {params}")
logging.debug(f"[DEBUG] Data: {data}")
if content_type == "application/x-www-form-urlencoded":
payload_args = self._create_urlencoded_form_data_args(data, request_headers)
elif content_type == "multipart/form-data":
payload_args = self._create_form_data_args(
data, files, request_headers, multipart_parser
)
else:
payload_args = self._create_json_payload_args(data, request_headers)
try:
response = requests.request(
method=method,
url=url,
params=params,
timeout=self.timeout,
verify=self.verify_ssl,
**payload_args,
)
# Raise exception for error status codes
response.raise_for_status()
except requests.ConnectionError:
raise Exception(
f"Unable to connect to the API server at {self.base_url}. Please check your internet connection or verify the service is available."
)
except requests.Timeout:
raise Exception(
f"Request timed out after {self.timeout} seconds. The server might be experiencing high load or the operation is taking longer than expected."
)
except requests.HTTPError as e:
status_code = e.response.status_code if hasattr(e, "response") else None
error_message = f"HTTP Error: {str(e)}"
# Try to extract detailed error message from JSON response
try:
if hasattr(e, "response") and e.response.content:
error_json = e.response.json()
if "error" in error_json and "message" in error_json["error"]:
error_message = f"API Error: {error_json['error']['message']}"
if "type" in error_json["error"]:
error_message += f" (Type: {error_json['error']['type']})"
else:
error_message = f"API Error: {error_json}"
except Exception as json_error:
# If we can't parse the JSON, fall back to the original error message
logging.debug(
f"[DEBUG] Failed to parse error response: {str(json_error)}"
)
logging.debug(f"[DEBUG] API Error: {error_message} (Status: {status_code})")
if hasattr(e, "response") and e.response.content:
logging.debug(f"[DEBUG] Response content: {e.response.content}")
if status_code == 401:
error_message = "Unauthorized: Please login first to use this node."
if status_code == 402:
error_message = "Payment Required: Please add credits to your account to use this node."
if status_code == 409:
error_message = "There is a problem with your account. Please contact support@comfy.org. "
if status_code == 429:
error_message = "Rate Limit Exceeded: Please try again later."
raise Exception(error_message)
# Parse and return JSON response
if response.content:
return response.json()
return {}
def check_auth(self, auth_token, comfy_api_key):
"""Verify that an auth token is present or comfy_api_key is present"""
if auth_token is None and comfy_api_key is None:
raise Exception("Unauthorized: Please login first to use this node.")
return auth_token or comfy_api_key
@staticmethod
def upload_file(
upload_url: str,
file: io.BytesIO | str,
content_type: str | None = None,
):
"""Upload a file to the API. Make sure the file has a filename equal to what the url expects.
Args:
upload_url: The URL to upload to
file: Either a file path string, BytesIO object, or tuple of (file_path, filename)
mime_type: Optional mime type to set for the upload
"""
headers = {}
if content_type:
headers["Content-Type"] = content_type
if isinstance(file, io.BytesIO):
file.seek(0) # Ensure we're at the start of the file
data = file.read()
return requests.put(upload_url, data=data, headers=headers)
elif isinstance(file, str):
with open(file, "rb") as f:
data = f.read()
return requests.put(upload_url, data=data, headers=headers)
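# Example (sketch): direct use of ApiClient for a one-off call. The route and key below
# are placeholders; real callers usually go through SynchronousOperation/PollingOperation instead.
def _example_api_client_usage():
    client = ApiClient(
        base_url="https://api.comfy.org",  # placeholder; normally taken from args.comfy_api_base
        comfy_api_key="example-key",       # placeholder credential
        timeout=30.0,
    )
    # GET with query parameters; returns the parsed JSON body, or {} for an empty response
    return client.request("GET", "/proxy/example/status", params={"id": "abc"})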
class ApiEndpoint(Generic[T, R]):
"""Defines an API endpoint with its request and response types"""
def __init__(
self,
path: str,
method: HttpMethod,
request_model: Type[T],
response_model: Type[R],
query_params: Optional[Dict[str, Any]] = None,
):
"""Initialize an API endpoint definition.
Args:
path: The URL path for this endpoint, can include placeholders like {id}
method: The HTTP method to use (GET, POST, etc.)
request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint
response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint
query_params: Optional dictionary of query parameters to include in the request
"""
self.path = path
self.method = method
self.request_model = request_model
self.response_model = response_model
self.query_params = query_params or {}
class SynchronousOperation(Generic[T, R]):
"""
Represents a single synchronous API operation.
"""
def __init__(
self,
endpoint: ApiEndpoint[T, R],
request: T,
files: Optional[Dict[str, Any]] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str,str]] = None,
timeout: float = 604800.0,
verify_ssl: bool = True,
content_type: str = "application/json",
multipart_parser: Callable = None,
):
self.endpoint = endpoint
self.request = request
self.response = None
self.error = None
self.api_base: str = api_base or args.comfy_api_base
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
if auth_kwargs is not None:
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
self.timeout = timeout
self.verify_ssl = verify_ssl
self.files = files
self.content_type = content_type
self.multipart_parser = multipart_parser
def execute(self, client: Optional[ApiClient] = None) -> R:
"""Execute the API operation using the provided client or create one"""
try:
# Create client if not provided
if client is None:
client = ApiClient(
base_url=self.api_base,
auth_token=self.auth_token,
comfy_api_key=self.comfy_api_key,
timeout=self.timeout,
verify_ssl=self.verify_ssl,
)
# Convert request model to dict, but use None for EmptyRequest
request_dict = (
None
if isinstance(self.request, EmptyRequest)
else self.request.model_dump(exclude_none=True)
)
if request_dict:
for key, value in request_dict.items():
if isinstance(value, Enum):
request_dict[key] = value.value
# Debug log for request
logging.debug(
f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
)
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
# Make the request
resp = client.request(
method=self.endpoint.method.value,
path=self.endpoint.path,
data=request_dict,
params=self.endpoint.query_params,
files=self.files,
content_type=self.content_type,
multipart_parser=self.multipart_parser
)
# Debug log for response
logging.debug("=" * 50)
logging.debug("[DEBUG] RESPONSE DETAILS:")
logging.debug("[DEBUG] Status Code: 200 (Success)")
logging.debug(f"[DEBUG] Response Body: {json.dumps(resp, indent=2)}")
logging.debug("=" * 50)
# Parse and return the response
return self._parse_response(resp)
except Exception as e:
logging.error(f"[DEBUG] API Exception: {str(e)}")
raise Exception(str(e))
def _parse_response(self, resp):
"""Parse response data - can be overridden by subclasses"""
# The response is already the complete object, don't extract just the "data" field
# as that would lose the outer structure (created timestamp, etc.)
# Parse response using the provided model
self.response = self.endpoint.response_model.model_validate(resp)
logging.debug(f"[DEBUG] Parsed Response: {self.response}")
return self.response
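# Example (sketch): wiring an ApiEndpoint into a SynchronousOperation. _ExampleStatusResponse
# and the route below are hypothetical stand-ins for the generated API models and real paths.
from pydantic import BaseModel

class _ExampleStatusResponse(BaseModel):
    status: str = "unknown"

def _example_synchronous_operation(auth_kwargs):
    endpoint = ApiEndpoint(
        path="/proxy/example/generate",  # placeholder route
        method=HttpMethod.POST,
        request_model=EmptyRequest,
        response_model=_ExampleStatusResponse,
    )
    op = SynchronousOperation(endpoint=endpoint, request=EmptyRequest(), auth_kwargs=auth_kwargs)
    return op.execute()  # returns a validated _ExampleStatusResponse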
class TaskStatus(str, Enum):
"""Enum for task status values"""
COMPLETED = "completed"
FAILED = "failed"
PENDING = "pending"
class PollingOperation(Generic[T, R]):
"""
Represents an asynchronous API operation that requires polling for completion.
"""
def __init__(
self,
poll_endpoint: ApiEndpoint[EmptyRequest, R],
completed_statuses: list,
failed_statuses: list,
status_extractor: Callable[[R], str],
progress_extractor: Callable[[R], float] = None,
request: Optional[T] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str,str]] = None,
poll_interval: float = 5.0,
):
self.poll_endpoint = poll_endpoint
self.request = request
self.api_base: str = api_base or args.comfy_api_base
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
if auth_kwargs is not None:
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
self.poll_interval = poll_interval
# Polling configuration
self.status_extractor = status_extractor or (
lambda x: getattr(x, "status", None)
)
self.progress_extractor = progress_extractor
self.completed_statuses = completed_statuses
self.failed_statuses = failed_statuses
# For storing response data
self.final_response = None
self.error = None
def execute(self, client: Optional[ApiClient] = None) -> R:
"""Execute the polling operation using the provided client. If failed, raise an exception."""
try:
if client is None:
client = ApiClient(
base_url=self.api_base,
auth_token=self.auth_token,
comfy_api_key=self.comfy_api_key,
)
return self._poll_until_complete(client)
except Exception as e:
raise Exception(f"Error during polling: {str(e)}")
def _check_task_status(self, response: R) -> TaskStatus:
"""Check task status using the status extractor function"""
try:
status = self.status_extractor(response)
if status in self.completed_statuses:
return TaskStatus.COMPLETED
elif status in self.failed_statuses:
return TaskStatus.FAILED
return TaskStatus.PENDING
except Exception as e:
logging.error(f"Error extracting status: {e}")
return TaskStatus.PENDING
def _poll_until_complete(self, client: ApiClient) -> R:
"""Poll until the task is complete"""
poll_count = 0
if self.progress_extractor:
progress = utils.ProgressBar(PROGRESS_BAR_MAX)
while True:
try:
poll_count += 1
logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
request_dict = (
self.request.model_dump(exclude_none=True)
if self.request is not None
else None
)
if poll_count == 1:
logging.debug(
f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}"
)
logging.debug(
f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}"
)
# Query task status
resp = client.request(
method=self.poll_endpoint.method.value,
path=self.poll_endpoint.path,
params=self.poll_endpoint.query_params,
data=request_dict,
)
# Parse response
response_obj = self.poll_endpoint.response_model.model_validate(resp)
# Check if task is complete
status = self._check_task_status(response_obj)
logging.debug(f"[DEBUG] Task Status: {status}")
# If progress extractor is provided, extract progress
if self.progress_extractor:
new_progress = self.progress_extractor(response_obj)
if new_progress is not None:
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
if status == TaskStatus.COMPLETED:
logging.debug("[DEBUG] Task completed successfully")
self.final_response = response_obj
if self.progress_extractor:
progress.update(100)
return self.final_response
elif status == TaskStatus.FAILED:
message = f"Task failed: {json.dumps(resp)}"
logging.error(f"[DEBUG] {message}")
raise Exception(message)
else:
logging.debug("[DEBUG] Task still pending, continuing to poll...")
# Wait before polling again
logging.debug(
f"[DEBUG] Waiting {self.poll_interval} seconds before next poll"
)
time.sleep(self.poll_interval)
except Exception as e:
logging.error(f"[DEBUG] Polling error: {str(e)}")
raise Exception(f"Error while polling: {str(e)}")

View File

@ -0,0 +1,253 @@
from __future__ import annotations
import torch
from enum import Enum
from typing import Optional, Union
from pydantic import BaseModel, Field, confloat
class LumaIO:
LUMA_REF = "LUMA_REF"
LUMA_CONCEPTS = "LUMA_CONCEPTS"
class LumaReference:
def __init__(self, image: torch.Tensor, weight: float):
self.image = image
self.weight = weight
def create_api_model(self, download_url: str):
return LumaImageRef(url=download_url, weight=self.weight)
class LumaReferenceChain:
def __init__(self, first_ref: LumaReference=None):
self.refs: list[LumaReference] = []
if first_ref:
self.refs.append(first_ref)
def add(self, luma_ref: LumaReference=None):
self.refs.append(luma_ref)
def create_api_model(self, download_urls: list[str], max_refs=4):
if len(self.refs) == 0:
return None
api_refs: list[LumaImageRef] = []
for ref, url in zip(self.refs, download_urls):
api_ref = LumaImageRef(url=url, weight=ref.weight)
api_refs.append(api_ref)
return api_refs
def clone(self):
c = LumaReferenceChain()
for ref in self.refs:
c.add(ref)
return c
class LumaConcept:
def __init__(self, key: str):
self.key = key
class LumaConceptChain:
def __init__(self, str_list: list[str] = None):
self.concepts: list[LumaConcept] = []
if str_list is not None:
for c in str_list:
if c != "None":
self.add(LumaConcept(key=c))
def add(self, concept: LumaConcept):
self.concepts.append(concept)
def create_api_model(self):
if len(self.concepts) == 0:
return None
api_concepts: list[LumaConceptObject] = []
for concept in self.concepts:
if concept.key == "None":
continue
api_concepts.append(LumaConceptObject(key=concept.key))
if len(api_concepts) == 0:
return None
return api_concepts
def clone(self):
c = LumaConceptChain()
for concept in self.concepts:
c.add(concept)
return c
def clone_and_merge(self, other: LumaConceptChain):
c = self.clone()
for concept in other.concepts:
c.add(concept)
return c
def get_luma_concepts(include_none=False):
concepts = []
if include_none:
concepts.append("None")
return concepts + [
"truck_left",
"pan_right",
"pedestal_down",
"low_angle",
"pedestal_up",
"selfie",
"pan_left",
"roll_right",
"zoom_in",
"over_the_shoulder",
"orbit_right",
"orbit_left",
"static",
"tiny_planet",
"high_angle",
"bolt_cam",
"dolly_zoom",
"overhead",
"zoom_out",
"handheld",
"roll_left",
"pov",
"aerial_drone",
"push_in",
"crane_down",
"truck_right",
"tilt_down",
"elevator_doors",
"tilt_up",
"ground_level",
"pull_out",
"aerial",
"crane_up",
"eye_level"
]
class LumaImageModel(str, Enum):
photon_1 = "photon-1"
photon_flash_1 = "photon-flash-1"
class LumaVideoModel(str, Enum):
ray_2 = "ray-2"
ray_flash_2 = "ray-flash-2"
ray_1_6 = "ray-1-6"
class LumaAspectRatio(str, Enum):
ratio_1_1 = "1:1"
ratio_16_9 = "16:9"
ratio_9_16 = "9:16"
ratio_4_3 = "4:3"
ratio_3_4 = "3:4"
ratio_21_9 = "21:9"
ratio_9_21 = "9:21"
class LumaVideoOutputResolution(str, Enum):
res_540p = "540p"
res_720p = "720p"
res_1080p = "1080p"
res_4k = "4k"
class LumaVideoModelOutputDuration(str, Enum):
dur_5s = "5s"
dur_9s = "9s"
class LumaGenerationType(str, Enum):
video = 'video'
image = 'image'
class LumaState(str, Enum):
queued = "queued"
dreaming = "dreaming"
completed = "completed"
failed = "failed"
class LumaAssets(BaseModel):
video: Optional[str] = Field(None, description='The URL of the video')
image: Optional[str] = Field(None, description='The URL of the image')
progress_video: Optional[str] = Field(None, description='The URL of the progress video')
class LumaImageRef(BaseModel):
'''Used for image gen'''
url: str = Field(..., description='The URL of the image reference')
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
class LumaImageReference(BaseModel):
'''Used for video gen'''
type: Optional[str] = Field('image', description='Input type, defaults to image')
url: str = Field(..., description='The URL of the image')
class LumaModifyImageRef(BaseModel):
url: str = Field(..., description='The URL of the image reference')
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
class LumaCharacterRef(BaseModel):
identity0: LumaImageIdentity = Field(..., description='The image identity object')
class LumaImageIdentity(BaseModel):
images: list[str] = Field(..., description='The URLs of the image identity')
class LumaGenerationReference(BaseModel):
type: str = Field('generation', description='Input type, defaults to generation')
id: str = Field(..., description='The ID of the generation')
class LumaKeyframes(BaseModel):
frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='')
frame1: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='')
class LumaConceptObject(BaseModel):
key: str = Field(..., description='Camera Concept name')
class LumaImageGenerationRequest(BaseModel):
prompt: str = Field(..., description='The prompt of the generation')
model: LumaImageModel = Field(LumaImageModel.photon_1, description='The image model used for the generation')
aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9, description='The aspect ratio of the generation')
image_ref: Optional[list[LumaImageRef]] = Field(None, description='List of image reference objects')
style_ref: Optional[list[LumaImageRef]] = Field(None, description='List of style reference objects')
character_ref: Optional[LumaCharacterRef] = Field(None, description='The image identity object')
modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description='The modify image reference object')
class LumaGenerationRequest(BaseModel):
prompt: str = Field(..., description='The prompt of the generation')
model: LumaVideoModel = Field(LumaVideoModel.ray_2, description='The video model used for the generation')
duration: Optional[LumaVideoModelOutputDuration] = Field(None, description='The duration of the generation')
aspect_ratio: Optional[LumaAspectRatio] = Field(None, description='The aspect ratio of the generation')
resolution: Optional[LumaVideoOutputResolution] = Field(None, description='The resolution of the generation')
loop: Optional[bool] = Field(None, description='Whether to loop the video')
keyframes: Optional[LumaKeyframes] = Field(None, description='The keyframes of the generation')
concepts: Optional[list[LumaConceptObject]] = Field(None, description='Camera Concepts to apply to generation')
class LumaGeneration(BaseModel):
id: str = Field(..., description='The ID of the generation')
generation_type: LumaGenerationType = Field(..., description='Generation type, image or video')
state: LumaState = Field(..., description='The state of the generation')
failure_reason: Optional[str] = Field(None, description='The reason for the state of the generation')
created_at: str = Field(..., description='The date and time when the generation was created')
assets: Optional[LumaAssets] = Field(None, description='The assets of the generation')
model: str = Field(..., description='The model used for the generation')
request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation")
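# Example (sketch): building a text+keyframe video request from these models.
# The prompt and URL are placeholders; image URLs normally come from uploading tensors first.
def _example_luma_video_request() -> LumaGenerationRequest:
    keyframes = LumaKeyframes(frame0=LumaImageReference(url="https://example.com/start.png"))
    return LumaGenerationRequest(
        prompt="a calm ocean at sunrise",
        model=LumaVideoModel.ray_2,
        aspect_ratio=LumaAspectRatio.ratio_16_9,
        resolution=LumaVideoOutputResolution.res_720p,
        duration=LumaVideoModelOutputDuration.dur_5s,
        keyframes=keyframes,
    )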

View File

@ -0,0 +1,146 @@
from __future__ import annotations
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
pixverse_templates = {
"Microwave": 324641385496960,
"Suit Swagger": 328545151283968,
"Anything, Robot": 313358700761536,
"Subject 3 Fever": 327828816843648,
"kiss kiss": 315446315336768,
}
class PixverseIO:
TEMPLATE = "PIXVERSE_TEMPLATE"
class PixverseStatus(int, Enum):
successful = 1
generating = 5
deleted = 6
contents_moderation = 7
failed = 8
class PixverseAspectRatio(str, Enum):
ratio_16_9 = "16:9"
ratio_4_3 = "4:3"
ratio_1_1 = "1:1"
ratio_3_4 = "3:4"
ratio_9_16 = "9:16"
class PixverseQuality(str, Enum):
res_360p = "360p"
res_540p = "540p"
res_720p = "720p"
res_1080p = "1080p"
class PixverseDuration(int, Enum):
dur_5 = 5
dur_8 = 8
class PixverseMotionMode(str, Enum):
normal = "normal"
fast = "fast"
class PixverseStyle(str, Enum):
anime = "anime"
animation_3d = "3d_animation"
clay = "clay"
comic = "comic"
cyberpunk = "cyberpunk"
# NOTE: forgoing descriptions for now in return for dev speed
class PixverseTextVideoRequest(BaseModel):
aspect_ratio: PixverseAspectRatio = Field(...)
quality: PixverseQuality = Field(...)
duration: PixverseDuration = Field(...)
model: Optional[str] = Field("v3.5")
motion_mode: Optional[PixverseMotionMode] = Field(PixverseMotionMode.normal)
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
style: Optional[str] = Field(None)
template_id: Optional[int] = Field(None)
water_mark: Optional[bool] = Field(None)
class PixverseImageVideoRequest(BaseModel):
quality: PixverseQuality = Field(...)
duration: PixverseDuration = Field(...)
img_id: int = Field(...)
model: Optional[str] = Field("v3.5")
motion_mode: Optional[PixverseMotionMode] = Field(PixverseMotionMode.normal)
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
style: Optional[str] = Field(None)
template_id: Optional[int] = Field(None)
water_mark: Optional[bool] = Field(None)
class PixverseTransitionVideoRequest(BaseModel):
quality: PixverseQuality = Field(...)
duration: PixverseDuration = Field(...)
first_frame_img: int = Field(...)
last_frame_img: int = Field(...)
model: Optional[str] = Field("v3.5")
motion_mode: Optional[PixverseMotionMode] = Field(PixverseMotionMode.normal)
prompt: str = Field(...)
# negative_prompt: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
# style: Optional[str] = Field(None)
# template_id: Optional[int] = Field(None)
# water_mark: Optional[bool] = Field(None)
class PixverseImageUploadResponse(BaseModel):
ErrCode: Optional[int] = None
ErrMsg: Optional[str] = None
Resp: Optional[PixverseImgIdResponseObject] = Field(None, alias='Resp')
class PixverseImgIdResponseObject(BaseModel):
img_id: Optional[int] = None
class PixverseVideoResponse(BaseModel):
ErrCode: Optional[int] = Field(None)
ErrMsg: Optional[str] = Field(None)
Resp: Optional[PixverseVideoIdResponseObject] = Field(None)
class PixverseVideoIdResponseObject(BaseModel):
video_id: int = Field(..., description='Video_id')
class PixverseGenerationStatusResponse(BaseModel):
ErrCode: Optional[int] = Field(None)
ErrMsg: Optional[str] = Field(None)
Resp: Optional[PixverseGenerationStatusResponseObject] = Field(None)
class PixverseGenerationStatusResponseObject(BaseModel):
create_time: Optional[str] = Field(None)
id: Optional[int] = Field(None)
modify_time: Optional[str] = Field(None)
negative_prompt: Optional[str] = Field(None)
outputHeight: Optional[int] = Field(None)
outputWidth: Optional[int] = Field(None)
prompt: Optional[str] = Field(None)
resolution_ratio: Optional[int] = Field(None)
seed: Optional[int] = Field(None)
size: Optional[int] = Field(None)
status: Optional[int] = Field(None)
style: Optional[str] = Field(None)
url: Optional[str] = Field(None)
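# Example (sketch): a minimal text-to-video request built from these models. The prompt is a
# placeholder; template_id, when used, comes from the pixverse_templates mapping above.
def _example_pixverse_text_video_request() -> PixverseTextVideoRequest:
    return PixverseTextVideoRequest(
        prompt="a corgi running through a field",
        aspect_ratio=PixverseAspectRatio.ratio_16_9,
        quality=PixverseQuality.res_540p,
        duration=PixverseDuration.dur_5,
        motion_mode=PixverseMotionMode.normal,
        template_id=pixverse_templates["Microwave"],
        seed=42,
    )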

View File

@ -0,0 +1,262 @@
from __future__ import annotations
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field, conint, confloat
class RecraftColor:
def __init__(self, r: int, g: int, b: int):
self.color = [r, g, b]
def create_api_model(self):
return RecraftColorObject(rgb=self.color)
class RecraftColorChain:
def __init__(self):
self.colors: list[RecraftColor] = []
def get_first(self):
if len(self.colors) > 0:
return self.colors[0]
return None
def add(self, color: RecraftColor):
self.colors.append(color)
def create_api_model(self):
if not self.colors:
return None
colors_api = [x.create_api_model() for x in self.colors]
return colors_api
def clone(self):
c = RecraftColorChain()
for color in self.colors:
c.add(color)
return c
def clone_and_merge(self, other: RecraftColorChain):
c = self.clone()
for color in other.colors:
c.add(color)
return c
class RecraftControls:
def __init__(self, colors: RecraftColorChain=None, background_color: RecraftColorChain=None,
artistic_level: int=None, no_text: bool=None):
self.colors = colors
self.background_color = background_color
self.artistic_level = artistic_level
self.no_text = no_text
def create_api_model(self):
if self.colors is None and self.background_color is None and self.artistic_level is None and self.no_text is None:
return None
colors_api = None
background_color_api = None
if self.colors:
colors_api = self.colors.create_api_model()
if self.background_color:
first_background = self.background_color.get_first()
background_color_api = first_background.create_api_model() if first_background else None
return RecraftControlsObject(colors=colors_api, background_color=background_color_api,
artistic_level=self.artistic_level, no_text=self.no_text)
class RecraftStyle:
def __init__(self, style: str=None, substyle: str=None, style_id: str=None):
self.style = style
if substyle == "None":
substyle = None
self.substyle = substyle
self.style_id = style_id
class RecraftIO:
STYLEV3 = "RECRAFT_V3_STYLE"
COLOR = "RECRAFT_COLOR"
CONTROLS = "RECRAFT_CONTROLS"
class RecraftStyleV3(str, Enum):
#any = 'any' NOTE: this does not work for some reason... why?
realistic_image = 'realistic_image'
digital_illustration = 'digital_illustration'
vector_illustration = 'vector_illustration'
logo_raster = 'logo_raster'
def get_v3_substyles(style_v3: str, include_none=True) -> list[str]:
substyles: list[str] = []
if include_none:
substyles.append("None")
return substyles + dict_recraft_substyles_v3.get(style_v3, [])
dict_recraft_substyles_v3 = {
RecraftStyleV3.realistic_image: [
"b_and_w",
"enterprise",
"evening_light",
"faded_nostalgia",
"forest_life",
"hard_flash",
"hdr",
"motion_blur",
"mystic_naturalism",
"natural_light",
"natural_tones",
"organic_calm",
"real_life_glow",
"retro_realism",
"retro_snapshot",
"studio_portrait",
"urban_drama",
"village_realism",
"warm_folk"
],
RecraftStyleV3.digital_illustration: [
"2d_art_poster",
"2d_art_poster_2",
"antiquarian",
"bold_fantasy",
"child_book",
"child_books",
"cover",
"crosshatch",
"digital_engraving",
"engraving_color",
"expressionism",
"freehand_details",
"grain",
"grain_20",
"graphic_intensity",
"hand_drawn",
"hand_drawn_outline",
"handmade_3d",
"hard_comics",
"infantile_sketch",
"long_shadow",
"modern_folk",
"multicolor",
"neon_calm",
"noir",
"nostalgic_pastel",
"outline_details",
"pastel_gradient",
"pastel_sketch",
"pixel_art",
"plastic",
"pop_art",
"pop_renaissance",
"seamless",
"street_art",
"tablet_sketch",
"urban_glow",
"urban_sketching",
"vanilla_dreams",
"young_adult_book",
"young_adult_book_2"
],
RecraftStyleV3.vector_illustration: [
"bold_stroke",
"chemistry",
"colored_stencil",
"contour_pop_art",
"cosmics",
"cutout",
"depressive",
"editorial",
"emotional_flat",
"engraving",
"infographical",
"line_art",
"line_circuit",
"linocut",
"marker_outline",
"mosaic",
"naivector",
"roundish_flat",
"seamless",
"segmented_colors",
"sharp_contrast",
"thin",
"vector_photo",
"vivid_shapes"
],
RecraftStyleV3.logo_raster: [
"emblem_graffiti",
"emblem_pop_art",
"emblem_punk",
"emblem_stamp",
"emblem_vintage"
],
}
class RecraftModel(str, Enum):
recraftv3 = 'recraftv3'
recraftv2 = 'recraftv2'
class RecraftImageSize(str, Enum):
res_1024x1024 = '1024x1024'
res_1365x1024 = '1365x1024'
res_1024x1365 = '1024x1365'
res_1536x1024 = '1536x1024'
res_1024x1536 = '1024x1536'
res_1820x1024 = '1820x1024'
res_1024x1820 = '1024x1820'
res_1024x2048 = '1024x2048'
res_2048x1024 = '2048x1024'
res_1434x1024 = '1434x1024'
res_1024x1434 = '1024x1434'
res_1024x1280 = '1024x1280'
res_1280x1024 = '1280x1024'
res_1024x1707 = '1024x1707'
res_1707x1024 = '1707x1024'
class RecraftColorObject(BaseModel):
rgb: list[int] = Field(..., description='An array of 3 integer values in range of 0...255 defining RGB Color Model')
class RecraftControlsObject(BaseModel):
colors: Optional[list[RecraftColorObject]] = Field(None, description='An array of preferable colors')
background_color: Optional[RecraftColorObject] = Field(None, description='Use given color as a desired background color')
no_text: Optional[bool] = Field(None, description='Do not embed text layouts')
artistic_level: Optional[conint(ge=0, le=5)] = Field(None, description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. Dynamic and eccentric levels introduce movement and creativity. The value should be in range [0..5].')
class RecraftImageGenerationRequest(BaseModel):
prompt: str = Field(..., description='The text prompt describing the image to generate')
size: Optional[RecraftImageSize] = Field(None, description='The size of the generated image (e.g., "1024x1024")')
n: conint(ge=1, le=6) = Field(..., description='The number of images to generate')
negative_prompt: Optional[str] = Field(None, description='A text description of undesired elements on an image')
model: Optional[RecraftModel] = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")')
style: Optional[str] = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
substyle: Optional[str] = Field(None, description='The substyle to apply to the generated image, depending on the style input')
controls: Optional[RecraftControlsObject] = Field(None, description='A set of custom parameters to tweak generation process')
style_id: Optional[str] = Field(None, description='Use a previously uploaded style as a reference; UUID')
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
random_seed: Optional[int] = Field(None, description="Seed for video generation")
# text_layout
class RecraftReturnedObject(BaseModel):
image_id: str = Field(..., description='Unique identifier for the generated image')
url: str = Field(..., description='URL to access the generated image')
class RecraftImageGenerationResponse(BaseModel):
created: int = Field(..., description='Unix timestamp when the generation was created')
credits: int = Field(..., description='Number of credits used for the generation')
data: Optional[list[RecraftReturnedObject]] = Field(None, description='Array of generated image information')
image: Optional[RecraftReturnedObject] = Field(None, description='Single generated image')
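# Example (sketch): combining a color chain and controls into a generation request.
# The prompt, color, and substyle values are placeholders chosen from the tables above.
def _example_recraft_request() -> RecraftImageGenerationRequest:
    colors = RecraftColorChain()
    colors.add(RecraftColor(255, 140, 0))
    controls = RecraftControls(colors=colors, artistic_level=2, no_text=True)
    style = RecraftStyle(style=RecraftStyleV3.digital_illustration.value, substyle="grain")
    return RecraftImageGenerationRequest(
        prompt="a lighthouse on a cliff at dusk",
        n=1,
        size=RecraftImageSize.res_1024x1024,
        model=RecraftModel.recraftv3,
        style=style.style,
        substyle=style.substyle,
        controls=controls.create_api_model(),
    )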

View File

@ -0,0 +1,127 @@
from __future__ import annotations
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field, confloat
class StabilityFormat(str, Enum):
png = 'png'
jpeg = 'jpeg'
webp = 'webp'
class StabilityAspectRatio(str, Enum):
ratio_1_1 = "1:1"
ratio_16_9 = "16:9"
ratio_9_16 = "9:16"
ratio_3_2 = "3:2"
ratio_2_3 = "2:3"
ratio_5_4 = "5:4"
ratio_4_5 = "4:5"
ratio_21_9 = "21:9"
ratio_9_21 = "9:21"
def get_stability_style_presets(include_none=True):
presets = []
if include_none:
presets.append("None")
return presets + [x.value for x in StabilityStylePreset]
class StabilityStylePreset(str, Enum):
_3d_model = "3d-model"
analog_film = "analog-film"
anime = "anime"
cinematic = "cinematic"
comic_book = "comic-book"
digital_art = "digital-art"
enhance = "enhance"
fantasy_art = "fantasy-art"
isometric = "isometric"
line_art = "line-art"
low_poly = "low-poly"
modeling_compound = "modeling-compound"
neon_punk = "neon-punk"
origami = "origami"
photographic = "photographic"
pixel_art = "pixel-art"
tile_texture = "tile-texture"
class Stability_SD3_5_Model(str, Enum):
sd3_5_large = "sd3.5-large"
# sd3_5_large_turbo = "sd3.5-large-turbo"
sd3_5_medium = "sd3.5-medium"
class Stability_SD3_5_GenerationMode(str, Enum):
text_to_image = "text-to-image"
image_to_image = "image-to-image"
class StabilityStable3_5Request(BaseModel):
model: str = Field(...)
mode: str = Field(...)
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
aspect_ratio: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
output_format: Optional[str] = Field(StabilityFormat.png.value)
image: Optional[str] = Field(None)
style_preset: Optional[str] = Field(None)
cfg_scale: float = Field(...)
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)
class StabilityUpscaleConservativeRequest(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
output_format: Optional[str] = Field(StabilityFormat.png.value)
image: Optional[str] = Field(None)
creativity: Optional[confloat(ge=0.2, le=0.5)] = Field(None)
class StabilityUpscaleCreativeRequest(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
output_format: Optional[str] = Field(StabilityFormat.png.value)
image: Optional[str] = Field(None)
creativity: Optional[confloat(ge=0.1, le=0.5)] = Field(None)
style_preset: Optional[str] = Field(None)
class StabilityStableUltraRequest(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
aspect_ratio: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
output_format: Optional[str] = Field(StabilityFormat.png.value)
image: Optional[str] = Field(None)
style_preset: Optional[str] = Field(None)
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)
class StabilityStableUltraResponse(BaseModel):
image: Optional[str] = Field(None)
finish_reason: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class StabilityResultsGetResponse(BaseModel):
image: Optional[str] = Field(None)
finish_reason: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
id: Optional[str] = Field(None)
name: Optional[str] = Field(None)
errors: Optional[list[str]] = Field(None)
status: Optional[str] = Field(None)
result: Optional[str] = Field(None)
class StabilityAsyncResponse(BaseModel):
id: Optional[str] = Field(None)
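# Example (sketch): a text-to-image request for the SD3.5 endpoint built from these models.
# The prompt and seed are placeholders; image/strength are only needed in image-to-image mode.
def _example_stability_sd3_5_request() -> StabilityStable3_5Request:
    return StabilityStable3_5Request(
        model=Stability_SD3_5_Model.sd3_5_large.value,
        mode=Stability_SD3_5_GenerationMode.text_to_image.value,
        prompt="an astronaut riding a horse, watercolor",
        aspect_ratio=StabilityAspectRatio.ratio_16_9.value,
        seed=7,
        output_format=StabilityFormat.png.value,
        style_preset=StabilityStylePreset.photographic.value,
        cfg_scale=4.0,
    )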

10
comfy_api_nodes/canary.py Normal file
View File

@ -0,0 +1,10 @@
import av
ver = av.__version__.split(".")
if int(ver[0]) < 14:
raise Exception("PyAV 14.2 or newer is required to use the API nodes; please upgrade the 'av' package.")
if int(ver[0]) == 14 and int(ver[1]) < 2:
raise Exception("PyAV 14.2 or newer is required to use the API nodes; please upgrade the 'av' package.")
NODE_CLASS_MAPPINGS = {}

View File

@ -0,0 +1,116 @@
from enum import Enum
from pydantic.fields import FieldInfo
from pydantic import BaseModel
from pydantic_core import PydanticUndefined
from comfy.comfy_types.node_typing import IO, InputTypeOptions
NodeInput = tuple[IO, InputTypeOptions]
def _create_base_config(field_info: FieldInfo) -> InputTypeOptions:
config = {}
if hasattr(field_info, "default") and field_info.default is not PydanticUndefined:
config["default"] = field_info.default
if hasattr(field_info, "description") and field_info.description is not None:
config["tooltip"] = field_info.description
return config
def _get_number_constraints_config(field_info: FieldInfo) -> dict:
config = {}
if hasattr(field_info, "metadata"):
metadata = field_info.metadata
for constraint in metadata:
if hasattr(constraint, "ge"):
config["min"] = constraint.ge
if hasattr(constraint, "le"):
config["max"] = constraint.le
if hasattr(constraint, "multiple_of"):
config["step"] = constraint.multiple_of
return config
def _model_field_to_image_input(field_info: FieldInfo, **kwargs) -> NodeInput:
return IO.IMAGE, {
**_create_base_config(field_info),
**kwargs,
}
def _model_field_to_string_input(field_info: FieldInfo, **kwargs) -> NodeInput:
return IO.STRING, {
**_create_base_config(field_info),
**kwargs,
}
def _model_field_to_float_input(field_info: FieldInfo, **kwargs) -> NodeInput:
return IO.FLOAT, {
**_create_base_config(field_info),
**_get_number_constraints_config(field_info),
**kwargs,
}
def _model_field_to_int_input(field_info: FieldInfo, **kwargs) -> NodeInput:
return IO.INT, {
**_create_base_config(field_info),
**_get_number_constraints_config(field_info),
**kwargs,
}
def _model_field_to_combo_input(
field_info: FieldInfo, enum_type: type[Enum] = None, **kwargs
) -> NodeInput:
combo_config = {}
if enum_type is not None:
combo_config["options"] = [option.value for option in enum_type]
combo_config = {
**combo_config,
**_create_base_config(field_info),
**kwargs,
}
return IO.COMBO, combo_config
def model_field_to_node_input(
input_type: IO, base_model: type[BaseModel], field_name: str, **kwargs
) -> NodeInput:
"""
Maps a field from a Pydantic model to a Comfy node input.
Args:
input_type: The type of the input.
base_model: The Pydantic model to map the field from.
field_name: The name of the field to map.
**kwargs: Additional key/values to include in the input options.
Note:
For combo inputs, pass an `Enum` to the `enum_type` keyword argument to populate the options automatically.
Example:
>>> model_field_to_node_input(IO.STRING, MyModel, "my_field", multiline=True)
>>> model_field_to_node_input(IO.COMBO, MyModel, "my_field", enum_type=MyEnum)
>>> model_field_to_node_input(IO.FLOAT, MyModel, "my_field", slider=True)
"""
field_info: FieldInfo = base_model.model_fields[field_name]
result: NodeInput
if input_type == IO.IMAGE:
result = _model_field_to_image_input(field_info, **kwargs)
elif input_type == IO.STRING:
result = _model_field_to_string_input(field_info, **kwargs)
elif input_type == IO.FLOAT:
result = _model_field_to_float_input(field_info, **kwargs)
elif input_type == IO.INT:
result = _model_field_to_int_input(field_info, **kwargs)
elif input_type == IO.COMBO:
result = _model_field_to_combo_input(field_info, **kwargs)
else:
message = f"Invalid input type: {input_type}"
raise ValueError(message)
return result
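# Example (sketch): using the mapper inside a node's INPUT_TYPES. _ExampleRequest is a
# hypothetical model, not one of the generated API models.
from pydantic import Field

class _ExampleRequest(BaseModel):
    prompt: str = Field("", description="Text prompt for the generation")
    steps: int = Field(20, ge=1, le=50, description="Number of sampling steps")

def _example_input_types():
    return {
        "required": {
            "prompt": model_field_to_node_input(IO.STRING, _ExampleRequest, "prompt", multiline=True),
            "steps": model_field_to_node_input(IO.INT, _ExampleRequest, "steps"),
        }
    }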

View File

@ -0,0 +1,906 @@
import io
from inspect import cleandoc
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api_nodes.apis.bfl_api import (
BFLStatus,
BFLFluxExpandImageRequest,
BFLFluxFillImageRequest,
BFLFluxCannyImageRequest,
BFLFluxDepthImageRequest,
BFLFluxProGenerateRequest,
BFLFluxProUltraGenerateRequest,
BFLFluxProGenerateResponse,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
)
from comfy_api_nodes.apinode_utils import (
downscale_image_tensor,
validate_aspect_ratio,
process_image_response,
resize_mask_to_image,
validate_string,
)
import numpy as np
from PIL import Image
import requests
import torch
import base64
import time
def convert_mask_to_image(mask: torch.Tensor):
"""
Make the mask have the expected number of dims (4) and channels (3) so it is recognized as an image.
"""
mask = mask.unsqueeze(-1)
mask = torch.cat([mask]*3, dim=-1)
return mask
def handle_bfl_synchronous_operation(
operation: SynchronousOperation, timeout_bfl_calls=360
):
response_api: BFLFluxProGenerateResponse = operation.execute()
return _poll_until_generated(
response_api.polling_url, timeout=timeout_bfl_calls
)
def _poll_until_generated(polling_url: str, timeout=360):
# implementation verified against bfl-comfy-nodes:
# https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main
start_time = time.time()
retries_404 = 0
max_retries_404 = 5
retry_404_seconds = 2
retry_202_seconds = 2
retry_pending_seconds = 1
request = requests.Request(method=HttpMethod.GET, url=polling_url)
# NOTE: should True loop be replaced with checking if workflow has been interrupted?
while True:
response = requests.Session().send(request.prepare())
if response.status_code == 200:
result = response.json()
if result["status"] == BFLStatus.ready:
img_url = result["result"]["sample"]
img_response = requests.get(img_url)
return process_image_response(img_response)
elif result["status"] in [
BFLStatus.request_moderated,
BFLStatus.content_moderated,
]:
status = result["status"]
raise Exception(
f"BFL API did not return an image due to: {status}."
)
elif result["status"] == BFLStatus.error:
raise Exception(f"BFL API encountered an error: {result}.")
elif result["status"] == BFLStatus.pending:
time.sleep(retry_pending_seconds)
continue
elif response.status_code == 404:
if retries_404 < max_retries_404:
retries_404 += 1
time.sleep(retry_404_seconds)
continue
raise Exception(
f"BFL API could not find task after {max_retries_404} tries."
)
elif response.status_code == 202:
time.sleep(retry_202_seconds)
elif time.time() - start_time > timeout:
raise Exception(
f"BFL API experienced a timeout; could not return request under {timeout} seconds."
)
else:
raise Exception(f"BFL API encountered an error: {response.json()}")
def convert_image_to_base64(image: torch.Tensor):
scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048)
# remove batch dimension if present
if len(scaled_image.shape) > 3:
scaled_image = scaled_image[0]
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format="PNG")
return base64.b64encode(img_byte_arr.getvalue()).decode()
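# Example (sketch): the helpers above assume ComfyUI-style image tensors, i.e. float tensors
# shaped [batch, height, width, channels] with values in 0..1; the shapes below are placeholders.
def _example_convert_helpers():
    image = torch.rand(1, 512, 512, 3)           # placeholder image batch
    mask = torch.ones(1, 512, 512)               # placeholder mask (no channel dim)
    mask_as_image = convert_mask_to_image(mask)  # -> shape [1, 512, 512, 3]
    b64_png = convert_image_to_base64(image)     # downscaled if needed, returned as base64 PNG
    return mask_as_image.shape, len(b64_png)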
class FluxProUltraImageNode(ComfyNodeABC):
"""
Generates images using Flux Pro 1.1 Ultra via the API, based on prompt and resolution.
"""
MINIMUM_RATIO = 1 / 4
MAXIMUM_RATIO = 4 / 1
MINIMUM_RATIO_STR = "1:4"
MAXIMUM_RATIO_STR = "4:1"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
),
"prompt_upsampling": (
IO.BOOLEAN,
{
"default": False,
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
"aspect_ratio": (
IO.STRING,
{
"default": "16:9",
"tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.",
},
),
"raw": (
IO.BOOLEAN,
{
"default": False,
"tooltip": "When True, generate less processed, more natural-looking images.",
},
),
},
"optional": {
"image_prompt": (IO.IMAGE,),
"image_prompt_strength": (
IO.FLOAT,
{
"default": 0.1,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Blend between the prompt and the image prompt.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
@classmethod
def VALIDATE_INPUTS(cls, aspect_ratio: str):
try:
validate_aspect_ratio(
aspect_ratio,
minimum_ratio=cls.MINIMUM_RATIO,
maximum_ratio=cls.MAXIMUM_RATIO,
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
)
except Exception as e:
return str(e)
return True
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/BFL"
def api_call(
self,
prompt: str,
aspect_ratio: str,
prompt_upsampling=False,
raw=False,
seed=0,
image_prompt=None,
image_prompt_strength=0.1,
**kwargs,
):
if image_prompt is None:
validate_string(prompt, strip_whitespace=False)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/bfl/flux-pro-1.1-ultra/generate",
method=HttpMethod.POST,
request_model=BFLFluxProUltraGenerateRequest,
response_model=BFLFluxProGenerateResponse,
),
request=BFLFluxProUltraGenerateRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
seed=seed,
aspect_ratio=validate_aspect_ratio(
aspect_ratio,
minimum_ratio=self.MINIMUM_RATIO,
maximum_ratio=self.MAXIMUM_RATIO,
minimum_ratio_str=self.MINIMUM_RATIO_STR,
maximum_ratio_str=self.MAXIMUM_RATIO_STR,
),
raw=raw,
image_prompt=(
image_prompt
if image_prompt is None
else convert_image_to_base64(image_prompt)
),
image_prompt_strength=(
None if image_prompt is None else round(image_prompt_strength, 2)
),
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
return (output_image,)
class FluxProImageNode(ComfyNodeABC):
"""
Generates images synchronously based on prompt and resolution.
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
),
"prompt_upsampling": (
IO.BOOLEAN,
{
"default": False,
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
},
),
"width": (
IO.INT,
{
"default": 1024,
"min": 256,
"max": 1440,
"step": 32,
},
),
"height": (
IO.INT,
{
"default": 768,
"min": 256,
"max": 1440,
"step": 32,
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"optional": {
"image_prompt": (IO.IMAGE,),
# "image_prompt_strength": (
# IO.FLOAT,
# {
# "default": 0.1,
# "min": 0.0,
# "max": 1.0,
# "step": 0.01,
# "tooltip": "Blend between the prompt and the image prompt.",
# },
# ),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/BFL"
def api_call(
self,
prompt: str,
prompt_upsampling,
width: int,
height: int,
seed=0,
image_prompt=None,
# image_prompt_strength=0.1,
**kwargs,
):
image_prompt = (
image_prompt
if image_prompt is None
else convert_image_to_base64(image_prompt)
)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/bfl/flux-pro-1.1/generate",
method=HttpMethod.POST,
request_model=BFLFluxProGenerateRequest,
response_model=BFLFluxProGenerateResponse,
),
request=BFLFluxProGenerateRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
width=width,
height=height,
seed=seed,
image_prompt=image_prompt,
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
return (output_image,)
class FluxProExpandNode(ComfyNodeABC):
"""
Outpaints image based on prompt.
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
),
"prompt_upsampling": (
IO.BOOLEAN,
{
"default": False,
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
},
),
"top": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2048,
"tooltip": "Number of pixels to expand at the top of the image"
},
),
"bottom": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2048,
"tooltip": "Number of pixels to expand at the bottom of the image"
},
),
"left": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2048,
"tooltip": "Number of pixels to expand at the left side of the image"
},
),
"right": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2048,
"tooltip": "Number of pixels to expand at the right side of the image"
},
),
"guidance": (
IO.FLOAT,
{
"default": 60,
"min": 1.5,
"max": 100,
"tooltip": "Guidance strength for the image generation process"
},
),
"steps": (
IO.INT,
{
"default": 50,
"min": 15,
"max": 50,
"tooltip": "Number of steps for the image generation process"
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"optional": {
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/BFL"
def api_call(
self,
image: torch.Tensor,
prompt: str,
prompt_upsampling: bool,
top: int,
bottom: int,
left: int,
right: int,
steps: int,
guidance: float,
seed=0,
**kwargs,
):
image = convert_image_to_base64(image)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/bfl/flux-pro-1.0-expand/generate",
method=HttpMethod.POST,
request_model=BFLFluxExpandImageRequest,
response_model=BFLFluxProGenerateResponse,
),
request=BFLFluxExpandImageRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
top=top,
bottom=bottom,
left=left,
right=right,
steps=steps,
guidance=guidance,
seed=seed,
image=image,
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
return (output_image,)
class FluxProFillNode(ComfyNodeABC):
"""
Inpaints image based on mask and prompt.
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
"mask": (IO.MASK,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
),
"prompt_upsampling": (
IO.BOOLEAN,
{
"default": False,
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
},
),
"guidance": (
IO.FLOAT,
{
"default": 60,
"min": 1.5,
"max": 100,
"tooltip": "Guidance strength for the image generation process"
},
),
"steps": (
IO.INT,
{
"default": 50,
"min": 15,
"max": 50,
"tooltip": "Number of steps for the image generation process"
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"optional": {
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/BFL"
def api_call(
self,
image: torch.Tensor,
mask: torch.Tensor,
prompt: str,
prompt_upsampling: bool,
steps: int,
guidance: float,
seed=0,
**kwargs,
):
# prepare mask
mask = resize_mask_to_image(mask, image)
mask = convert_image_to_base64(convert_mask_to_image(mask))
# make sure image will have alpha channel removed
image = convert_image_to_base64(image[:,:,:,:3])
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/bfl/flux-pro-1.0-fill/generate",
method=HttpMethod.POST,
request_model=BFLFluxFillImageRequest,
response_model=BFLFluxProGenerateResponse,
),
request=BFLFluxFillImageRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
steps=steps,
guidance=guidance,
seed=seed,
image=image,
mask=mask,
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
return (output_image,)
class FluxProCannyNode(ComfyNodeABC):
"""
Generate image using a control image (canny).
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"control_image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
),
"prompt_upsampling": (
IO.BOOLEAN,
{
"default": False,
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
},
),
"canny_low_threshold": (
IO.FLOAT,
{
"default": 0.1,
"min": 0.01,
"max": 0.99,
"step": 0.01,
"tooltip": "Low threshold for Canny edge detection; ignored if skip_processing is True"
},
),
"canny_high_threshold": (
IO.FLOAT,
{
"default": 0.4,
"min": 0.01,
"max": 0.99,
"step": 0.01,
"tooltip": "High threshold for Canny edge detection; ignored if skip_processing is True"
},
),
"skip_preprocessing": (
IO.BOOLEAN,
{
"default": False,
"tooltip": "Whether to skip preprocessing; set to True if control_image already is canny-fied, False if it is a raw image.",
},
),
"guidance": (
IO.FLOAT,
{
"default": 30,
"min": 1,
"max": 100,
"tooltip": "Guidance strength for the image generation process"
},
),
"steps": (
IO.INT,
{
"default": 50,
"min": 15,
"max": 50,
"tooltip": "Number of steps for the image generation process"
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"optional": {
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/BFL"
def api_call(
self,
control_image: torch.Tensor,
prompt: str,
prompt_upsampling: bool,
canny_low_threshold: float,
canny_high_threshold: float,
skip_preprocessing: bool,
steps: int,
guidance: float,
seed=0,
**kwargs,
):
control_image = convert_image_to_base64(control_image[:,:,:,:3])
preprocessed_image = None
# scale canny thresholds from the 0-1 UI range to 0-500, to match BFL's API
def scale_value(value: float, min_val=0, max_val=500):
return min_val + value * (max_val - min_val)
canny_low_threshold = int(round(scale_value(canny_low_threshold)))
canny_high_threshold = int(round(scale_value(canny_high_threshold)))
if skip_preprocessing:
preprocessed_image = control_image
control_image = None
canny_low_threshold = None
canny_high_threshold = None
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/bfl/flux-pro-1.0-canny/generate",
method=HttpMethod.POST,
request_model=BFLFluxCannyImageRequest,
response_model=BFLFluxProGenerateResponse,
),
request=BFLFluxCannyImageRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
steps=steps,
guidance=guidance,
seed=seed,
control_image=control_image,
canny_low_threshold=canny_low_threshold,
canny_high_threshold=canny_high_threshold,
preprocessed_image=preprocessed_image,
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
return (output_image,)
class FluxProDepthNode(ComfyNodeABC):
"""
Generate image using a control image (depth).
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"control_image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
),
"prompt_upsampling": (
IO.BOOLEAN,
{
"default": False,
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
},
),
"skip_preprocessing": (
IO.BOOLEAN,
{
"default": False,
"tooltip": "Whether to skip preprocessing; set to True if control_image already is depth-ified, False if it is a raw image.",
},
),
"guidance": (
IO.FLOAT,
{
"default": 15,
"min": 1,
"max": 100,
"tooltip": "Guidance strength for the image generation process"
},
),
"steps": (
IO.INT,
{
"default": 50,
"min": 15,
"max": 50,
"tooltip": "Number of steps for the image generation process"
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"optional": {
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/BFL"
def api_call(
self,
control_image: torch.Tensor,
prompt: str,
prompt_upsampling: bool,
skip_preprocessing: bool,
steps: int,
guidance: float,
seed=0,
**kwargs,
):
control_image = convert_image_to_base64(control_image[:,:,:,:3])
preprocessed_image = None
if skip_preprocessing:
preprocessed_image = control_image
control_image = None
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/bfl/flux-pro-1.0-depth/generate",
method=HttpMethod.POST,
request_model=BFLFluxDepthImageRequest,
response_model=BFLFluxProGenerateResponse,
),
request=BFLFluxDepthImageRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
steps=steps,
guidance=guidance,
seed=seed,
control_image=control_image,
preprocessed_image=preprocessed_image,
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
return (output_image,)
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"FluxProUltraImageNode": FluxProUltraImageNode,
# "FluxProImageNode": FluxProImageNode,
"FluxProExpandNode": FluxProExpandNode,
"FluxProFillNode": FluxProFillNode,
"FluxProCannyNode": FluxProCannyNode,
"FluxProDepthNode": FluxProDepthNode,
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"FluxProUltraImageNode": "Flux 1.1 [pro] Ultra Image",
# "FluxProImageNode": "Flux 1.1 [pro] Image",
"FluxProExpandNode": "Flux.1 Expand Image",
"FluxProFillNode": "Flux.1 Fill Image",
"FluxProCannyNode": "Flux.1 Canny Control Image",
"FluxProDepthNode": "Flux.1 Depth Control Image",
}

View File

@ -0,0 +1,779 @@
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
from inspect import cleandoc
from PIL import Image
import numpy as np
import io
import torch
from comfy_api_nodes.apis import (
IdeogramGenerateRequest,
IdeogramGenerateResponse,
ImageRequest,
IdeogramV3Request,
IdeogramV3EditRequest,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_bytesio,
bytesio_to_image_tensor,
resize_mask_to_image,
)
V1_V1_RES_MAP = {
"Auto":"AUTO",
"512 x 1536":"RESOLUTION_512_1536",
"576 x 1408":"RESOLUTION_576_1408",
"576 x 1472":"RESOLUTION_576_1472",
"576 x 1536":"RESOLUTION_576_1536",
"640 x 1024":"RESOLUTION_640_1024",
"640 x 1344":"RESOLUTION_640_1344",
"640 x 1408":"RESOLUTION_640_1408",
"640 x 1472":"RESOLUTION_640_1472",
"640 x 1536":"RESOLUTION_640_1536",
"704 x 1152":"RESOLUTION_704_1152",
"704 x 1216":"RESOLUTION_704_1216",
"704 x 1280":"RESOLUTION_704_1280",
"704 x 1344":"RESOLUTION_704_1344",
"704 x 1408":"RESOLUTION_704_1408",
"704 x 1472":"RESOLUTION_704_1472",
"720 x 1280":"RESOLUTION_720_1280",
"736 x 1312":"RESOLUTION_736_1312",
"768 x 1024":"RESOLUTION_768_1024",
"768 x 1088":"RESOLUTION_768_1088",
"768 x 1152":"RESOLUTION_768_1152",
"768 x 1216":"RESOLUTION_768_1216",
"768 x 1232":"RESOLUTION_768_1232",
"768 x 1280":"RESOLUTION_768_1280",
"768 x 1344":"RESOLUTION_768_1344",
"832 x 960":"RESOLUTION_832_960",
"832 x 1024":"RESOLUTION_832_1024",
"832 x 1088":"RESOLUTION_832_1088",
"832 x 1152":"RESOLUTION_832_1152",
"832 x 1216":"RESOLUTION_832_1216",
"832 x 1248":"RESOLUTION_832_1248",
"864 x 1152":"RESOLUTION_864_1152",
"896 x 960":"RESOLUTION_896_960",
"896 x 1024":"RESOLUTION_896_1024",
"896 x 1088":"RESOLUTION_896_1088",
"896 x 1120":"RESOLUTION_896_1120",
"896 x 1152":"RESOLUTION_896_1152",
"960 x 832":"RESOLUTION_960_832",
"960 x 896":"RESOLUTION_960_896",
"960 x 1024":"RESOLUTION_960_1024",
"960 x 1088":"RESOLUTION_960_1088",
"1024 x 640":"RESOLUTION_1024_640",
"1024 x 768":"RESOLUTION_1024_768",
"1024 x 832":"RESOLUTION_1024_832",
"1024 x 896":"RESOLUTION_1024_896",
"1024 x 960":"RESOLUTION_1024_960",
"1024 x 1024":"RESOLUTION_1024_1024",
"1088 x 768":"RESOLUTION_1088_768",
"1088 x 832":"RESOLUTION_1088_832",
"1088 x 896":"RESOLUTION_1088_896",
"1088 x 960":"RESOLUTION_1088_960",
"1120 x 896":"RESOLUTION_1120_896",
"1152 x 704":"RESOLUTION_1152_704",
"1152 x 768":"RESOLUTION_1152_768",
"1152 x 832":"RESOLUTION_1152_832",
"1152 x 864":"RESOLUTION_1152_864",
"1152 x 896":"RESOLUTION_1152_896",
"1216 x 704":"RESOLUTION_1216_704",
"1216 x 768":"RESOLUTION_1216_768",
"1216 x 832":"RESOLUTION_1216_832",
"1232 x 768":"RESOLUTION_1232_768",
"1248 x 832":"RESOLUTION_1248_832",
"1280 x 704":"RESOLUTION_1280_704",
"1280 x 720":"RESOLUTION_1280_720",
"1280 x 768":"RESOLUTION_1280_768",
"1280 x 800":"RESOLUTION_1280_800",
"1312 x 736":"RESOLUTION_1312_736",
"1344 x 640":"RESOLUTION_1344_640",
"1344 x 704":"RESOLUTION_1344_704",
"1344 x 768":"RESOLUTION_1344_768",
"1408 x 576":"RESOLUTION_1408_576",
"1408 x 640":"RESOLUTION_1408_640",
"1408 x 704":"RESOLUTION_1408_704",
"1472 x 576":"RESOLUTION_1472_576",
"1472 x 640":"RESOLUTION_1472_640",
"1472 x 704":"RESOLUTION_1472_704",
"1536 x 512":"RESOLUTION_1536_512",
"1536 x 576":"RESOLUTION_1536_576",
"1536 x 640":"RESOLUTION_1536_640",
}
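# Maps aspect-ratio labels to the Ideogram V1/V2 ASPECT_* enum values.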
V1_V2_RATIO_MAP = {
"1:1":"ASPECT_1_1",
"4:3":"ASPECT_4_3",
"3:4":"ASPECT_3_4",
"16:9":"ASPECT_16_9",
"9:16":"ASPECT_9_16",
"2:1":"ASPECT_2_1",
"1:2":"ASPECT_1_2",
"3:2":"ASPECT_3_2",
"2:3":"ASPECT_2_3",
"4:5":"ASPECT_4_5",
"5:4":"ASPECT_5_4",
}
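# Ideogram V3 expects plain "1x3"-style ratio strings instead of the ASPECT_* enum names.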
V3_RATIO_MAP = {
"1:3":"1x3",
"3:1":"3x1",
"1:2":"1x2",
"2:1":"2x1",
"9:16":"9x16",
"16:9":"16x9",
"10:16":"10x16",
"16:10":"16x10",
"2:3":"2x3",
"3:2":"3x2",
"3:4":"3x4",
"4:3":"4x3",
"4:5":"4x5",
"5:4":"5x4",
"1:1":"1x1",
}
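# Resolutions accepted directly by the V3 endpoint; "Auto" omits the field so the aspect ratio is used instead.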
V3_RESOLUTIONS= [
"Auto",
"512x1536",
"576x1408",
"576x1472",
"576x1536",
"640x1344",
"640x1408",
"640x1472",
"640x1536",
"704x1152",
"704x1216",
"704x1280",
"704x1344",
"704x1408",
"704x1472",
"736x1312",
"768x1088",
"768x1216",
"768x1280",
"768x1344",
"800x1280",
"832x960",
"832x1024",
"832x1088",
"832x1152",
"832x1216",
"832x1248",
"864x1152",
"896x960",
"896x1024",
"896x1088",
"896x1120",
"896x1152",
"960x832",
"960x896",
"960x1024",
"960x1088",
"1024x832",
"1024x896",
"1024x960",
"1024x1024",
"1088x768",
"1088x832",
"1088x896",
"1088x960",
"1120x896",
"1152x704",
"1152x832",
"1152x864",
"1152x896",
"1216x704",
"1216x768",
"1216x832",
"1248x832",
"1280x704",
"1280x768",
"1280x800",
"1312x736",
"1344x640",
"1344x704",
"1344x768",
"1408x576",
"1408x640",
"1408x704",
"1472x576",
"1472x640",
"1472x704",
"1536x512",
"1536x576",
"1536x640"
]
def download_and_process_images(image_urls):
"""Helper function to download and process multiple images from URLs"""
# Initialize list to store image tensors
image_tensors = []
for image_url in image_urls:
# Using functions from apinode_utils.py to handle downloading and processing
image_bytesio = download_url_to_bytesio(image_url) # Download image content to BytesIO
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
image_tensors.append(img_tensor)
# Stack tensors into a single (N, height, width, channels) batch
if not image_tensors:
raise Exception("No valid images were processed")
return torch.cat(image_tensors, dim=0)
class IdeogramV1(ComfyNodeABC):
"""
Generates images using the Ideogram V1 model.
"""
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
),
"turbo": (
IO.BOOLEAN,
{
"default": False,
"tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)",
}
),
},
"optional": {
"aspect_ratio": (
IO.COMBO,
{
"options": list(V1_V2_RATIO_MAP.keys()),
"default": "1:1",
"tooltip": "The aspect ratio for image generation.",
},
),
"magic_prompt_option": (
IO.COMBO,
{
"options": ["AUTO", "ON", "OFF"],
"default": "AUTO",
"tooltip": "Determine if MagicPrompt should be used in generation",
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"step": 1,
"control_after_generate": True,
"display": "number",
},
),
"negative_prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Description of what to exclude from the image",
},
),
"num_images": (
IO.INT,
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "api_call"
CATEGORY = "api node/image/Ideogram/v1"
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
def api_call(
self,
prompt,
turbo=False,
aspect_ratio="1:1",
magic_prompt_option="AUTO",
seed=0,
negative_prompt="",
num_images=1,
**kwargs,
):
# Determine the model based on turbo setting
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
model = "V_1_TURBO" if turbo else "V_1"
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/ideogram/generate",
method=HttpMethod.POST,
request_model=IdeogramGenerateRequest,
response_model=IdeogramGenerateResponse,
),
request=IdeogramGenerateRequest(
image_request=ImageRequest(
prompt=prompt,
model=model,
num_images=num_images,
seed=seed,
aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
magic_prompt_option=(
magic_prompt_option if magic_prompt_option != "AUTO" else None
),
negative_prompt=negative_prompt if negative_prompt else None,
)
),
auth_kwargs=kwargs,
)
response = operation.execute()
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
return (download_and_process_images(image_urls),)
class IdeogramV2(ComfyNodeABC):
"""
Generates images using the Ideogram V2 model.
"""
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
),
"turbo": (
IO.BOOLEAN,
{
"default": False,
"tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)",
}
),
},
"optional": {
"aspect_ratio": (
IO.COMBO,
{
"options": list(V1_V2_RATIO_MAP.keys()),
"default": "1:1",
"tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
},
),
"resolution": (
IO.COMBO,
{
"options": list(V1_V1_RES_MAP.keys()),
"default": "Auto",
"tooltip": "The resolution for image generation. If not set to AUTO, this overrides the aspect_ratio setting.",
},
),
"magic_prompt_option": (
IO.COMBO,
{
"options": ["AUTO", "ON", "OFF"],
"default": "AUTO",
"tooltip": "Determine if MagicPrompt should be used in generation",
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"step": 1,
"control_after_generate": True,
"display": "number",
},
),
"style_type": (
IO.COMBO,
{
"options": ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
"default": "NONE",
"tooltip": "Style type for generation (V2 only)",
},
),
"negative_prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Description of what to exclude from the image",
},
),
"num_images": (
IO.INT,
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
),
#"color_palette": (
# IO.STRING,
# {
# "multiline": False,
# "default": "",
# "tooltip": "Color palette preset name or hex colors with weights",
# },
#),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "api_call"
CATEGORY = "api node/image/Ideogram/v2"
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
def api_call(
self,
prompt,
turbo=False,
aspect_ratio="1:1",
resolution="Auto",
magic_prompt_option="AUTO",
seed=0,
style_type="NONE",
negative_prompt="",
num_images=1,
color_palette="",
**kwargs,
):
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
resolution = V1_V1_RES_MAP.get(resolution, None)
# Determine the model based on turbo setting
model = "V_2_TURBO" if turbo else "V_2"
# Handle resolution vs aspect_ratio logic
# If resolution is not AUTO, it overrides aspect_ratio
final_resolution = None
final_aspect_ratio = None
if resolution != "AUTO":
final_resolution = resolution
else:
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/ideogram/generate",
method=HttpMethod.POST,
request_model=IdeogramGenerateRequest,
response_model=IdeogramGenerateResponse,
),
request=IdeogramGenerateRequest(
image_request=ImageRequest(
prompt=prompt,
model=model,
num_images=num_images,
seed=seed,
aspect_ratio=final_aspect_ratio,
resolution=final_resolution,
magic_prompt_option=(
magic_prompt_option if magic_prompt_option != "AUTO" else None
),
style_type=style_type if style_type != "NONE" else None,
negative_prompt=negative_prompt if negative_prompt else None,
color_palette=color_palette if color_palette else None,
)
),
auth_kwargs=kwargs,
)
response = operation.execute()
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
return (download_and_process_images(image_urls),)
class IdeogramV3(ComfyNodeABC):
"""
Generates images using the Ideogram V3 model. Supports both regular image generation from text prompts and image editing with a mask.
"""
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation or editing",
},
),
},
"optional": {
"image": (
IO.IMAGE,
{
"default": None,
"tooltip": "Optional reference image for image editing.",
},
),
"mask": (
IO.MASK,
{
"default": None,
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
},
),
"aspect_ratio": (
IO.COMBO,
{
"options": list(V3_RATIO_MAP.keys()),
"default": "1:1",
"tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to Auto.",
},
),
"resolution": (
IO.COMBO,
{
"options": V3_RESOLUTIONS,
"default": "Auto",
"tooltip": "The resolution for image generation. If not set to Auto, this overrides the aspect_ratio setting.",
},
),
"magic_prompt_option": (
IO.COMBO,
{
"options": ["AUTO", "ON", "OFF"],
"default": "AUTO",
"tooltip": "Determine if MagicPrompt should be used in generation",
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"step": 1,
"control_after_generate": True,
"display": "number",
},
),
"num_images": (
IO.INT,
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
),
"rendering_speed": (
IO.COMBO,
{
"options": ["BALANCED", "TURBO", "QUALITY"],
"default": "BALANCED",
"tooltip": "Controls the trade-off between generation speed and quality",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "api_call"
CATEGORY = "api node/image/Ideogram/v3"
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
def api_call(
self,
prompt,
image=None,
mask=None,
resolution="Auto",
aspect_ratio="1:1",
magic_prompt_option="AUTO",
seed=0,
num_images=1,
rendering_speed="BALANCED",
**kwargs,
):
# Check if both image and mask are provided for editing mode
if image is not None and mask is not None:
# Edit mode
path = "/proxy/ideogram/ideogram-v3/edit"
# Process image and mask
input_tensor = image.squeeze().cpu()
# Resize mask to match image dimension
mask = resize_mask_to_image(mask, image, allow_gradient=False)
# Invert mask, as Ideogram API will edit black areas instead of white areas (opposite of convention).
mask = 1.0 - mask
# Validate mask dimensions match image
if mask.shape[1:] != image.shape[1:-1]:
raise Exception("Mask and Image must be the same size")
# Process image
img_np = (input_tensor.numpy() * 255).astype(np.uint8)
img = Image.fromarray(img_np)
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format="PNG")
img_byte_arr.seek(0)
img_binary = img_byte_arr
img_binary.name = "image.png"
# Process mask - white areas will be replaced
mask_np = (mask.squeeze().cpu().numpy() * 255).astype(np.uint8)
mask_img = Image.fromarray(mask_np)
mask_byte_arr = io.BytesIO()
mask_img.save(mask_byte_arr, format="PNG")
mask_byte_arr.seek(0)
mask_binary = mask_byte_arr
mask_binary.name = "mask.png"
# Create edit request
edit_request = IdeogramV3EditRequest(
prompt=prompt,
rendering_speed=rendering_speed,
)
# Add optional parameters
if magic_prompt_option != "AUTO":
edit_request.magic_prompt = magic_prompt_option
if seed != 0:
edit_request.seed = seed
if num_images > 1:
edit_request.num_images = num_images
# Execute the operation for edit mode
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=IdeogramV3EditRequest,
response_model=IdeogramGenerateResponse,
),
request=edit_request,
files={
"image": img_binary,
"mask": mask_binary,
},
content_type="multipart/form-data",
auth_kwargs=kwargs,
)
elif image is not None or mask is not None:
# If only one of image or mask is provided, raise an error
raise Exception("Ideogram V3 image editing requires both an image AND a mask")
else:
# Generation mode
path = "/proxy/ideogram/ideogram-v3/generate"
# Create generation request
gen_request = IdeogramV3Request(
prompt=prompt,
rendering_speed=rendering_speed,
)
# Handle resolution vs aspect ratio
if resolution != "Auto":
gen_request.resolution = resolution
elif aspect_ratio != "1:1":
v3_aspect = V3_RATIO_MAP.get(aspect_ratio)
if v3_aspect:
gen_request.aspect_ratio = v3_aspect
# Add optional parameters
if magic_prompt_option != "AUTO":
gen_request.magic_prompt = magic_prompt_option
if seed != 0:
gen_request.seed = seed
if num_images > 1:
gen_request.num_images = num_images
# Execute the operation for generation mode
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=IdeogramV3Request,
response_model=IdeogramGenerateResponse,
),
request=gen_request,
auth_kwargs=kwargs,
)
# Execute the operation and process response
response = operation.execute()
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
return (download_and_process_images(image_urls),)
NODE_CLASS_MAPPINGS = {
"IdeogramV1": IdeogramV1,
"IdeogramV2": IdeogramV2,
"IdeogramV3": IdeogramV3,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"IdeogramV1": "Ideogram V1",
"IdeogramV2": "Ideogram V2",
"IdeogramV3": "Ideogram V3",
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,704 @@
from __future__ import annotations
from inspect import cleandoc
from typing import Optional
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis.luma_api import (
LumaImageModel,
LumaVideoModel,
LumaVideoOutputResolution,
LumaVideoModelOutputDuration,
LumaAspectRatio,
LumaState,
LumaImageGenerationRequest,
LumaGenerationRequest,
LumaGeneration,
LumaCharacterRef,
LumaModifyImageRef,
LumaImageIdentity,
LumaReference,
LumaReferenceChain,
LumaImageReference,
LumaKeyframes,
LumaConceptChain,
LumaIO,
get_luma_concepts,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
upload_images_to_comfyapi,
process_image_response,
validate_string,
)
import requests
import torch
from io import BytesIO
class LumaReferenceNode(ComfyNodeABC):
"""
Holds an image and weight for use with Luma Generate Image node.
"""
RETURN_TYPES = (LumaIO.LUMA_REF,)
RETURN_NAMES = ("luma_ref",)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "create_luma_reference"
CATEGORY = "api node/image/Luma"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (
IO.IMAGE,
{
"tooltip": "Image to use as reference.",
},
),
"weight": (
IO.FLOAT,
{
"default": 1.0,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Weight of image reference.",
},
),
},
"optional": {"luma_ref": (LumaIO.LUMA_REF,)},
}
def create_luma_reference(
self, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
):
if luma_ref is not None:
luma_ref = luma_ref.clone()
else:
luma_ref = LumaReferenceChain()
luma_ref.add(LumaReference(image=image, weight=round(weight, 2)))
return (luma_ref,)
class LumaConceptsNode(ComfyNodeABC):
"""
Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
"""
RETURN_TYPES = (LumaIO.LUMA_CONCEPTS,)
RETURN_NAMES = ("luma_concepts",)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "create_concepts"
CATEGORY = "api node/video/Luma"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"concept1": (get_luma_concepts(include_none=True),),
"concept2": (get_luma_concepts(include_none=True),),
"concept3": (get_luma_concepts(include_none=True),),
"concept4": (get_luma_concepts(include_none=True),),
},
"optional": {
"luma_concepts": (
LumaIO.LUMA_CONCEPTS,
{
"tooltip": "Optional Camera Concepts to add to the ones chosen here."
},
),
},
}
def create_concepts(
self,
concept1: str,
concept2: str,
concept3: str,
concept4: str,
luma_concepts: LumaConceptChain = None,
):
chain = LumaConceptChain(str_list=[concept1, concept2, concept3, concept4])
if luma_concepts is not None:
chain = luma_concepts.clone_and_merge(chain)
return (chain,)
class LumaImageGenerationNode(ComfyNodeABC):
"""
Generates images synchronously based on prompt and aspect ratio.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Luma"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
),
"model": ([model.value for model in LumaImageModel],),
"aspect_ratio": (
[ratio.value for ratio in LumaAspectRatio],
{
"default": LumaAspectRatio.ratio_16_9,
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
"style_image_weight": (
IO.FLOAT,
{
"default": 1.0,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Weight of style image. Ignored if no style_image provided.",
},
),
},
"optional": {
"image_luma_ref": (
LumaIO.LUMA_REF,
{
"tooltip": "Luma Reference node connection to influence generation with input images; up to 4 images can be considered."
},
),
"style_image": (
IO.IMAGE,
{"tooltip": "Style reference image; only 1 image will be used."},
),
"character_image": (
IO.IMAGE,
{
"tooltip": "Character reference images; can be a batch of multiple, up to 4 images can be considered."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
prompt: str,
model: str,
aspect_ratio: str,
seed,
style_image_weight: float,
image_luma_ref: LumaReferenceChain = None,
style_image: torch.Tensor = None,
character_image: torch.Tensor = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=True, min_length=3)
# handle image_luma_ref
api_image_ref = None
if image_luma_ref is not None:
api_image_ref = self._convert_luma_refs(
image_luma_ref, max_refs=4, auth_kwargs=kwargs,
)
# handle style_luma_ref
api_style_ref = None
if style_image is not None:
api_style_ref = self._convert_style_image(
style_image, weight=style_image_weight, auth_kwargs=kwargs,
)
# handle character_ref images
character_ref = None
if character_image is not None:
download_urls = upload_images_to_comfyapi(
character_image, max_images=4, auth_kwargs=kwargs,
)
character_ref = LumaCharacterRef(
identity0=LumaImageIdentity(images=download_urls)
)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations/image",
method=HttpMethod.POST,
request_model=LumaImageGenerationRequest,
response_model=LumaGeneration,
),
request=LumaImageGenerationRequest(
prompt=prompt,
model=model,
aspect_ratio=aspect_ratio,
image_ref=api_image_ref,
style_ref=api_style_ref,
character_ref=character_ref,
),
auth_kwargs=kwargs,
)
response_api: LumaGeneration = operation.execute()
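# The create call returns immediately with a generation id; poll that id until Luma reports a completed (or failed) state.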
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
auth_kwargs=kwargs,
)
response_poll = operation.execute()
img_response = requests.get(response_poll.assets.image)
img = process_image_response(img_response)
return (img,)
def _convert_luma_refs(
self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
):
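"""Upload up to max_refs reference images via upload_images_to_comfyapi and build the Luma image_ref payload from the resulting URLs."""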
luma_urls = []
ref_count = 0
for ref in luma_ref.refs:
download_urls = upload_images_to_comfyapi(
ref.image, max_images=1, auth_kwargs=auth_kwargs
)
luma_urls.append(download_urls[0])
ref_count += 1
if ref_count >= max_refs:
break
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
def _convert_style_image(
self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
):
chain = LumaReferenceChain(
first_ref=LumaReference(image=style_image, weight=weight)
)
return self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
class LumaImageModifyNode(ComfyNodeABC):
"""
Modifies images synchronously based on prompt and aspect ratio.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Luma"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation",
},
),
"image_weight": (
IO.FLOAT,
{
"default": 0.1,
"min": 0.0,
"max": 0.98,
"step": 0.01,
"tooltip": "Weight of the image; the closer to 1.0, the less the image will be modified.",
},
),
"model": ([model.value for model in LumaImageModel],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
},
"optional": {},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
prompt: str,
model: str,
image: torch.Tensor,
image_weight: float,
seed,
**kwargs,
):
# first, upload image
download_urls = upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=kwargs,
)
image_url = download_urls[0]
# next, make Luma call with download url provided
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations/image",
method=HttpMethod.POST,
request_model=LumaImageGenerationRequest,
response_model=LumaGeneration,
),
request=LumaImageGenerationRequest(
prompt=prompt,
model=model,
modify_image_ref=LumaModifyImageRef(
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
),
),
auth_kwargs=kwargs,
)
response_api: LumaGeneration = operation.execute()
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
auth_kwargs=kwargs,
)
response_poll = operation.execute()
img_response = requests.get(response_poll.assets.image)
img = process_image_response(img_response)
return (img,)
class LumaTextToVideoGenerationNode(ComfyNodeABC):
"""
Generates videos synchronously based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/Luma"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"model": ([model.value for model in LumaVideoModel],),
"aspect_ratio": (
[ratio.value for ratio in LumaAspectRatio],
{
"default": LumaAspectRatio.ratio_16_9,
},
),
"resolution": (
[resolution.value for resolution in LumaVideoOutputResolution],
{
"default": LumaVideoOutputResolution.res_540p,
},
),
"duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
"loop": (
IO.BOOLEAN,
{
"default": False,
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
},
"optional": {
"luma_concepts": (
LumaIO.LUMA_CONCEPTS,
{
"tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
prompt: str,
model: str,
aspect_ratio: str,
resolution: str,
duration: str,
loop: bool,
seed,
luma_concepts: LumaConceptChain = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False, min_length=3)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations",
method=HttpMethod.POST,
request_model=LumaGenerationRequest,
response_model=LumaGeneration,
),
request=LumaGenerationRequest(
prompt=prompt,
model=model,
resolution=resolution,
aspect_ratio=aspect_ratio,
duration=duration,
loop=loop,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_kwargs=kwargs,
)
response_api: LumaGeneration = operation.execute()
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
auth_kwargs=kwargs,
)
response_poll = operation.execute()
vid_response = requests.get(response_poll.assets.video)
return (VideoFromFile(BytesIO(vid_response.content)),)
class LumaImageToVideoGenerationNode(ComfyNodeABC):
"""
Generates videos synchronously based on prompt, input images, and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/Luma"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"model": ([model.value for model in LumaVideoModel],),
# "aspect_ratio": ([ratio.value for ratio in LumaAspectRatio], {
# "default": LumaAspectRatio.ratio_16_9,
# }),
"resolution": (
[resolution.value for resolution in LumaVideoOutputResolution],
{
"default": LumaVideoOutputResolution.res_540p,
},
),
"duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
"loop": (
IO.BOOLEAN,
{
"default": False,
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
},
),
},
"optional": {
"first_image": (
IO.IMAGE,
{"tooltip": "First frame of generated video."},
),
"last_image": (IO.IMAGE, {"tooltip": "Last frame of generated video."}),
"luma_concepts": (
LumaIO.LUMA_CONCEPTS,
{
"tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
prompt: str,
model: str,
resolution: str,
duration: str,
loop: bool,
seed,
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
luma_concepts: LumaConceptChain = None,
**kwargs,
):
if first_image is None and last_image is None:
raise Exception(
"At least one of first_image and last_image requires an input."
)
keyframes = self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
duration = duration if model != LumaVideoModel.ray_1_6 else None
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/luma/generations",
method=HttpMethod.POST,
request_model=LumaGenerationRequest,
response_model=LumaGeneration,
),
request=LumaGenerationRequest(
prompt=prompt,
model=model,
aspect_ratio=LumaAspectRatio.ratio_16_9, # ignored, but still needed by the API for some reason
resolution=resolution,
duration=duration,
loop=loop,
keyframes=keyframes,
concepts=luma_concepts.create_api_model() if luma_concepts else None,
),
auth_kwargs=kwargs,
)
response_api: LumaGeneration = operation.execute()
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=LumaGeneration,
),
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
auth_kwargs=kwargs,
)
response_poll = operation.execute()
vid_response = requests.get(response_poll.assets.video)
return (VideoFromFile(BytesIO(vid_response.content)),)
def _convert_to_keyframes(
self,
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
auth_kwargs: Optional[dict[str,str]] = None,
):
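"""Upload the provided first/last frame images and wrap their URLs in a LumaKeyframes object (or return None if neither is given)."""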
if first_image is None and last_image is None:
return None
frame0 = None
frame1 = None
if first_image is not None:
download_urls = upload_images_to_comfyapi(
first_image, max_images=1, auth_kwargs=auth_kwargs,
)
frame0 = LumaImageReference(type="image", url=download_urls[0])
if last_image is not None:
download_urls = upload_images_to_comfyapi(
last_image, max_images=1, auth_kwargs=auth_kwargs,
)
frame1 = LumaImageReference(type="image", url=download_urls[0])
return LumaKeyframes(frame0=frame0, frame1=frame1)
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"LumaImageNode": LumaImageGenerationNode,
"LumaImageModifyNode": LumaImageModifyNode,
"LumaVideoNode": LumaTextToVideoGenerationNode,
"LumaImageToVideoNode": LumaImageToVideoGenerationNode,
"LumaReferenceNode": LumaReferenceNode,
"LumaConceptsNode": LumaConceptsNode,
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"LumaImageNode": "Luma Text to Image",
"LumaImageModifyNode": "Luma Image to Image",
"LumaVideoNode": "Luma Text to Video",
"LumaImageToVideoNode": "Luma Image to Video",
"LumaReferenceNode": "Luma Reference",
"LumaConceptsNode": "Luma Concepts",
}

View File

@ -0,0 +1,309 @@
from comfy.comfy_types.node_typing import IO
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis import (
MinimaxVideoGenerationRequest,
MinimaxVideoGenerationResponse,
MinimaxFileRetrieveResponse,
MinimaxTaskResultResponse,
SubjectReferenceItem,
Model
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_bytesio,
upload_images_to_comfyapi,
validate_string,
)
import torch
import logging
class MinimaxTextToVideoNode:
"""
Generates videos synchronously from a text prompt and optional parameters using MiniMax's API.
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt_text": (
"STRING",
{
"multiline": True,
"default": "",
"tooltip": "Text prompt to guide the video generation",
},
),
"model": (
[
"T2V-01",
"T2V-01-Director",
],
{
"default": "T2V-01",
"tooltip": "Model to use for video generation",
},
),
},
"optional": {
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
RETURN_TYPES = ("VIDEO",)
DESCRIPTION = "Generates videos from prompts using MiniMax's API"
FUNCTION = "generate_video"
CATEGORY = "api node/video/MiniMax"
API_NODE = True
OUTPUT_NODE = True
def generate_video(
self,
prompt_text,
seed=0,
model="T2V-01",
image: torch.Tensor=None, # used for ImageToVideo
subject: torch.Tensor=None, # used for SubjectToVideo
**kwargs,
):
'''
Function used between MiniMax nodes - supports T2V, I2V, and S2V, based on provided arguments.
'''
if image is None:
validate_string(prompt_text, field_name="prompt_text")
# upload image, if passed in
image_url = None
if image is not None:
image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)[0]
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
subject_reference = None
if subject is not None:
subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs)[0]
subject_reference = [SubjectReferenceItem(image=subject_url)]
video_generate_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/video_generation",
method=HttpMethod.POST,
request_model=MinimaxVideoGenerationRequest,
response_model=MinimaxVideoGenerationResponse,
),
request=MinimaxVideoGenerationRequest(
model=Model(model),
prompt=prompt_text,
callback_url=None,
first_frame_image=image_url,
subject_reference=subject_reference,
prompt_optimizer=None,
),
auth_kwargs=kwargs,
)
response = video_generate_operation.execute()
task_id = response.task_id
if not task_id:
raise Exception(f"MiniMax generation failed: {response.base_resp}")
video_generate_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/minimax/query/video_generation",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxTaskResultResponse,
query_params={"task_id": task_id},
),
completed_statuses=["Success"],
failed_statuses=["Fail"],
status_extractor=lambda x: x.status.value,
auth_kwargs=kwargs,
)
task_result = video_generate_operation.execute()
file_id = task_result.file_id
if file_id is None:
raise Exception("Request was not successful. Missing file ID.")
file_retrieve_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/files/retrieve",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxFileRetrieveResponse,
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_kwargs=kwargs,
)
file_result = file_retrieve_operation.execute()
file_url = file_result.file.download_url
if file_url is None:
raise Exception(
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info(f"Generated video URL: {file_url}")
video_io = download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)
raise Exception(error_msg)
return (VideoFromFile(video_io),)
class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
"""
Generates videos synchronously from a first-frame image, a prompt, and optional parameters using MiniMax's API.
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (
IO.IMAGE,
{
"tooltip": "Image to use as first frame of video generation"
},
),
"prompt_text": (
"STRING",
{
"multiline": True,
"default": "",
"tooltip": "Text prompt to guide the video generation",
},
),
"model": (
[
"I2V-01-Director",
"I2V-01",
"I2V-01-live",
],
{
"default": "I2V-01",
"tooltip": "Model to use for video generation",
},
),
},
"optional": {
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
RETURN_TYPES = ("VIDEO",)
DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API"
FUNCTION = "generate_video"
CATEGORY = "api node/video/MiniMax"
API_NODE = True
OUTPUT_NODE = True
class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
"""
Generates videos synchronously from a subject reference image, a prompt, and optional parameters using MiniMax's API.
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"subject": (
IO.IMAGE,
{
"tooltip": "Image of subject to reference video generation"
},
),
"prompt_text": (
"STRING",
{
"multiline": True,
"default": "",
"tooltip": "Text prompt to guide the video generation",
},
),
"model": (
[
"S2V-01",
],
{
"default": "S2V-01",
"tooltip": "Model to use for video generation",
},
),
},
"optional": {
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
RETURN_TYPES = ("VIDEO",)
DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API"
FUNCTION = "generate_video"
CATEGORY = "api node/video/MiniMax"
API_NODE = True
OUTPUT_NODE = True
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"MinimaxTextToVideoNode": MinimaxTextToVideoNode,
"MinimaxImageToVideoNode": MinimaxImageToVideoNode,
# "MinimaxSubjectToVideoNode": MinimaxSubjectToVideoNode,
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"MinimaxTextToVideoNode": "MiniMax Text to Video",
"MinimaxImageToVideoNode": "MiniMax Image to Video",
"MinimaxSubjectToVideoNode": "MiniMax Subject to Video",
}

View File

@ -0,0 +1,496 @@
import io
from inspect import cleandoc
import numpy as np
import torch
from PIL import Image
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
from comfy_api_nodes.apis import (
OpenAIImageGenerationRequest,
OpenAIImageEditRequest,
OpenAIImageGenerationResponse,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
)
from comfy_api_nodes.apinode_utils import (
downscale_image_tensor,
validate_and_cast_response,
validate_string,
)
class OpenAIDalle2(ComfyNodeABC):
"""
Generates images synchronously via OpenAI's DALL·E 2 endpoint.
"""
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Text prompt for DALL·E",
},
),
},
"optional": {
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2**31 - 1,
"step": 1,
"display": "number",
"control_after_generate": True,
"tooltip": "not implemented yet in backend",
},
),
"size": (
IO.COMBO,
{
"options": ["256x256", "512x512", "1024x1024"],
"default": "1024x1024",
"tooltip": "Image size",
},
),
"n": (
IO.INT,
{
"default": 1,
"min": 1,
"max": 8,
"step": 1,
"display": "number",
"tooltip": "How many images to generate",
},
),
"image": (
IO.IMAGE,
{
"default": None,
"tooltip": "Optional reference image for image editing.",
},
),
"mask": (
IO.MASK,
{
"default": None,
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "api_call"
CATEGORY = "api node/image/OpenAI"
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
def api_call(
self,
prompt,
seed=0,
image=None,
mask=None,
n=1,
size="1024x1024",
**kwargs
):
validate_string(prompt, strip_whitespace=False)
model = "dall-e-2"
path = "/proxy/openai/images/generations"
content_type = "application/json"
request_class = OpenAIImageGenerationRequest
img_binary = None
if image is not None and mask is not None:
path = "/proxy/openai/images/edits"
content_type = "multipart/form-data"
request_class = OpenAIImageEditRequest
input_tensor = image.squeeze().cpu()
height, width, channels = input_tensor.shape
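# Build an RGBA image whose alpha channel encodes the inverted mask: transparent pixels mark the region DALL·E 2 should repaint.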
rgba_tensor = torch.ones(height, width, 4, device="cpu")
rgba_tensor[:, :, :channels] = input_tensor
if mask.shape[1:] != image.shape[1:-1]:
raise Exception("Mask and Image must be the same size")
rgba_tensor[:, :, 3] = 1 - mask.squeeze().cpu()
rgba_tensor = downscale_image_tensor(rgba_tensor.unsqueeze(0)).squeeze()
image_np = (rgba_tensor.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format="PNG")
img_byte_arr.seek(0)
img_binary = img_byte_arr # .getvalue()
img_binary.name = "image.png"
elif image is not None or mask is not None:
raise Exception("Dall-E 2 image editing requires an image AND a mask")
# Build the operation
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=request_class,
response_model=OpenAIImageGenerationResponse,
),
request=request_class(
model=model,
prompt=prompt,
n=n,
size=size,
seed=seed,
),
files=(
{
"image": img_binary,
}
if img_binary
else None
),
content_type=content_type,
auth_kwargs=kwargs,
)
response = operation.execute()
img_tensor = validate_and_cast_response(response)
return (img_tensor,)
class OpenAIDalle3(ComfyNodeABC):
"""
Generates images synchronously via OpenAI's DALL·E 3 endpoint.
"""
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Text prompt for DALL·E",
},
),
},
"optional": {
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2**31 - 1,
"step": 1,
"display": "number",
"control_after_generate": True,
"tooltip": "not implemented yet in backend",
},
),
"quality": (
IO.COMBO,
{
"options": ["standard", "hd"],
"default": "standard",
"tooltip": "Image quality",
},
),
"style": (
IO.COMBO,
{
"options": ["natural", "vivid"],
"default": "natural",
"tooltip": "Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images.",
},
),
"size": (
IO.COMBO,
{
"options": ["1024x1024", "1024x1792", "1792x1024"],
"default": "1024x1024",
"tooltip": "Image size",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "api_call"
CATEGORY = "api node/image/OpenAI"
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
def api_call(
self,
prompt,
seed=0,
style="natural",
quality="standard",
size="1024x1024",
**kwargs
):
validate_string(prompt, strip_whitespace=False)
model = "dall-e-3"
# build the operation
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/openai/images/generations",
method=HttpMethod.POST,
request_model=OpenAIImageGenerationRequest,
response_model=OpenAIImageGenerationResponse,
),
request=OpenAIImageGenerationRequest(
model=model,
prompt=prompt,
quality=quality,
size=size,
style=style,
seed=seed,
),
auth_kwargs=kwargs,
)
response = operation.execute()
img_tensor = validate_and_cast_response(response)
return (img_tensor,)
class OpenAIGPTImage1(ComfyNodeABC):
"""
Generates images synchronously via OpenAI's GPT Image 1 endpoint.
"""
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Text prompt for GPT Image 1",
},
),
},
"optional": {
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2**31 - 1,
"step": 1,
"display": "number",
"control_after_generate": True,
"tooltip": "not implemented yet in backend",
},
),
"quality": (
IO.COMBO,
{
"options": ["low", "medium", "high"],
"default": "low",
"tooltip": "Image quality, affects cost and generation time.",
},
),
"background": (
IO.COMBO,
{
"options": ["opaque", "transparent"],
"default": "opaque",
"tooltip": "Return image with or without background",
},
),
"size": (
IO.COMBO,
{
"options": ["auto", "1024x1024", "1024x1536", "1536x1024"],
"default": "auto",
"tooltip": "Image size",
},
),
"n": (
IO.INT,
{
"default": 1,
"min": 1,
"max": 8,
"step": 1,
"display": "number",
"tooltip": "How many images to generate",
},
),
"image": (
IO.IMAGE,
{
"default": None,
"tooltip": "Optional reference image for image editing.",
},
),
"mask": (
IO.MASK,
{
"default": None,
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
RETURN_TYPES = (IO.IMAGE,)
FUNCTION = "api_call"
CATEGORY = "api node/image/OpenAI"
DESCRIPTION = cleandoc(__doc__ or "")
API_NODE = True
def api_call(
self,
prompt,
seed=0,
quality="low",
background="opaque",
image=None,
mask=None,
n=1,
size="1024x1024",
**kwargs
):
validate_string(prompt, strip_whitespace=False)
model = "gpt-image-1"
path = "/proxy/openai/images/generations"
content_type="application/json"
request_class = OpenAIImageGenerationRequest
img_binaries = []
mask_binary = None
files = []
if image is not None:
path = "/proxy/openai/images/edits"
request_class = OpenAIImageEditRequest
content_type ="multipart/form-data"
batch_size = image.shape[0]
for i in range(batch_size):
single_image = image[i : i + 1]
scaled_image = downscale_image_tensor(single_image).squeeze()
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format="PNG")
img_byte_arr.seek(0)
img_binary = img_byte_arr
img_binary.name = f"image_{i}.png"
img_binaries.append(img_binary)
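# Use the array-style "image[]" field name when more than one image is uploaded; a single image keeps the plain "image" field.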
if batch_size == 1:
files.append(("image", img_binary))
else:
files.append(("image[]", img_binary))
if mask is not None:
if image is None:
raise Exception("Cannot use a mask without an input image")
if image.shape[0] != 1:
raise Exception("Cannot use a mask with multiple image")
if mask.shape[1:] != image.shape[1:-1]:
raise Exception("Mask and Image must be the same size")
batch, height, width = mask.shape
rgba_mask = torch.zeros(height, width, 4, device="cpu")
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0)).squeeze()
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
mask_img = Image.fromarray(mask_np)
mask_img_byte_arr = io.BytesIO()
mask_img.save(mask_img_byte_arr, format="PNG")
mask_img_byte_arr.seek(0)
mask_binary = mask_img_byte_arr
mask_binary.name = "mask.png"
files.append(("mask", mask_binary))
# Build the operation
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=request_class,
response_model=OpenAIImageGenerationResponse,
),
request=request_class(
model=model,
prompt=prompt,
quality=quality,
background=background,
n=n,
seed=seed,
size=size,
),
files=files if files else None,
content_type=content_type,
auth_kwargs=kwargs,
)
response = operation.execute()
img_tensor = validate_and_cast_response(response)
return (img_tensor,)
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"OpenAIDalle2": OpenAIDalle2,
"OpenAIDalle3": OpenAIDalle3,
"OpenAIGPTImage1": OpenAIGPTImage1,
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"OpenAIDalle2": "OpenAI DALL·E 2",
"OpenAIDalle3": "OpenAI DALL·E 3",
"OpenAIGPTImage1": "OpenAI GPT Image 1",
}

View File

@ -0,0 +1,757 @@
"""
Pika x ComfyUI API Nodes
Pika API docs: https://pika-827374fb.mintlify.app/api-reference
"""
from __future__ import annotations
import io
from typing import Optional, TypeVar
import logging
import torch
import numpy as np
from comfy_api_nodes.apis import (
PikaBodyGenerate22T2vGenerate22T2vPost,
PikaGenerateResponse,
PikaBodyGenerate22I2vGenerate22I2vPost,
PikaVideoResponse,
PikaBodyGenerate22C2vGenerate22PikascenesPost,
IngredientsMode,
PikaDurationEnum,
PikaResolutionEnum,
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
Pikaffect,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
tensor_to_bytesio,
download_url_to_video_output,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy_api.input_impl.video_types import VideoInput, VideoContainer, VideoCodec
from comfy_api.input_impl import VideoFromFile
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeOptions
R = TypeVar("R")
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
PATH_PIKAFFECTS = "/proxy/pika/generate/pikaffects"
PIKA_API_VERSION = "2.2"
PATH_TEXT_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/t2v"
PATH_IMAGE_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/i2v"
PATH_PIKAFRAMES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikaframes"
PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
PATH_VIDEO_GET = "/proxy/pika/videos"
class PikaApiError(Exception):
"""Exception for Pika API errors."""
pass
def is_valid_video_response(response: PikaVideoResponse) -> bool:
"""Check if the video response is valid."""
return hasattr(response, "url") and response.url is not None
def is_valid_initial_response(response: PikaGenerateResponse) -> bool:
"""Check if the initial response is valid."""
return hasattr(response, "video_id") and response.video_id is not None
class PikaNodeBase(ComfyNodeABC):
"""Base class for Pika nodes."""
@classmethod
def get_base_inputs_types(
cls, request_model
) -> dict[str, tuple[IO, InputTypeOptions]]:
"""Get the base required inputs types common to all Pika nodes."""
return {
"prompt_text": model_field_to_node_input(
IO.STRING,
request_model,
"promptText",
multiline=True,
),
"negative_prompt": model_field_to_node_input(
IO.STRING,
request_model,
"negativePrompt",
multiline=True,
),
"seed": model_field_to_node_input(
IO.INT,
request_model,
"seed",
min=0,
max=0xFFFFFFFF,
control_after_generate=True,
),
"resolution": model_field_to_node_input(
IO.COMBO,
request_model,
"resolution",
enum_type=PikaResolutionEnum,
),
"duration": model_field_to_node_input(
IO.COMBO,
request_model,
"duration",
enum_type=PikaDurationEnum,
),
}
CATEGORY = "api node/video/Pika"
API_NODE = True
FUNCTION = "api_call"
RETURN_TYPES = ("VIDEO",)
def poll_for_task_status(
self, task_id: str, auth_kwargs: Optional[dict[str,str]] = None
) -> PikaVideoResponse:
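"""Poll the Pika video endpoint for task_id until it reaches a finished, failed, or cancelled status."""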
polling_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"{PATH_VIDEO_GET}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=PikaVideoResponse,
),
completed_statuses=[
"finished",
],
failed_statuses=["failed", "cancelled"],
status_extractor=lambda response: (
response.status.value if response.status else None
),
progress_extractor=lambda response: (
response.progress if hasattr(response, "progress") else None
),
auth_kwargs=auth_kwargs,
)
return polling_operation.execute()
def execute_task(
self,
initial_operation: SynchronousOperation[R, PikaGenerateResponse],
auth_kwargs: Optional[dict[str,str]] = None,
) -> tuple[VideoFromFile]:
"""Executes the initial operation then polls for the task status until it is completed.
Args:
initial_operation: The initial operation to execute.
auth_kwargs: The authentication token(s) to use for the API call.
Returns:
A tuple containing the video file as a VIDEO output.
"""
initial_response = initial_operation.execute()
if not is_valid_initial_response(initial_response):
error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}"
logging.error(error_msg)
raise PikaApiError(error_msg)
task_id = initial_response.video_id
final_response = self.poll_for_task_status(task_id, auth_kwargs)
if not is_valid_video_response(final_response):
error_msg = (
f"Pika task {task_id} succeeded but no video data found in response."
)
logging.error(error_msg)
raise PikaApiError(error_msg)
video_url = str(final_response.url)
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
return (download_url_to_video_output(video_url),)
class PikaImageToVideoV2_2(PikaNodeBase):
"""Pika 2.2 Image to Video Node."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": (
IO.IMAGE,
{"tooltip": "The image to convert to video"},
),
**cls.get_base_inputs_types(PikaBodyGenerate22I2vGenerate22I2vPost),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
DESCRIPTION = "Sends an image and prompt to the Pika API v2.2 to generate a video."
def api_call(
self,
image: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
**kwargs
) -> tuple[VideoFromFile]:
# Convert image to BytesIO
image_bytes_io = tensor_to_bytesio(image)
image_bytes_io.seek(0)
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
# Prepare non-file data
pika_request_data = PikaBodyGenerate22I2vGenerate22I2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,
method=HttpMethod.POST,
request_model=PikaBodyGenerate22I2vGenerate22I2vPost,
response_model=PikaGenerateResponse,
),
request=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_kwargs=kwargs)
class PikaTextToVideoNodeV2_2(PikaNodeBase):
"""Pika Text2Video v2.2 Node."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
**cls.get_base_inputs_types(PikaBodyGenerate22T2vGenerate22T2vPost),
"aspect_ratio": model_field_to_node_input(
IO.FLOAT,
PikaBodyGenerate22T2vGenerate22T2vPost,
"aspectRatio",
step=0.001,
min=0.4,
max=2.5,
default=1.7777777777777777,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
DESCRIPTION = "Sends a text prompt to the Pika API v2.2 to generate a video."
def api_call(
self,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
aspect_ratio: float,
**kwargs,
) -> tuple[VideoFromFile]:
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_TEXT_TO_VIDEO,
method=HttpMethod.POST,
request_model=PikaBodyGenerate22T2vGenerate22T2vPost,
response_model=PikaGenerateResponse,
),
request=PikaBodyGenerate22T2vGenerate22T2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
aspectRatio=aspect_ratio,
),
auth_kwargs=kwargs,
content_type="application/x-www-form-urlencoded",
)
return self.execute_task(initial_operation, auth_kwargs=kwargs)
class PikaScenesV2_2(PikaNodeBase):
"""PikaScenes v2.2 Node."""
@classmethod
def INPUT_TYPES(cls):
image_ingredient_input = (
IO.IMAGE,
{"tooltip": "Image that will be used as ingredient to create a video."},
)
return {
"required": {
**cls.get_base_inputs_types(
PikaBodyGenerate22C2vGenerate22PikascenesPost,
),
"ingredients_mode": model_field_to_node_input(
IO.COMBO,
PikaBodyGenerate22C2vGenerate22PikascenesPost,
"ingredientsMode",
enum_type=IngredientsMode,
default="creative",
),
"aspect_ratio": model_field_to_node_input(
IO.FLOAT,
PikaBodyGenerate22C2vGenerate22PikascenesPost,
"aspectRatio",
step=0.001,
min=0.4,
max=2.5,
default=1.7777777777777777,
),
},
"optional": {
"image_ingredient_1": image_ingredient_input,
"image_ingredient_2": image_ingredient_input,
"image_ingredient_3": image_ingredient_input,
"image_ingredient_4": image_ingredient_input,
"image_ingredient_5": image_ingredient_input,
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
DESCRIPTION = "Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them."
def api_call(
self,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
ingredients_mode: str,
aspect_ratio: float,
image_ingredient_1: Optional[torch.Tensor] = None,
image_ingredient_2: Optional[torch.Tensor] = None,
image_ingredient_3: Optional[torch.Tensor] = None,
image_ingredient_4: Optional[torch.Tensor] = None,
image_ingredient_5: Optional[torch.Tensor] = None,
**kwargs,
) -> tuple[VideoFromFile]:
# Convert all passed images to BytesIO
all_image_bytes_io = []
for image in [
image_ingredient_1,
image_ingredient_2,
image_ingredient_3,
image_ingredient_4,
image_ingredient_5,
]:
if image is not None:
image_bytes_io = tensor_to_bytesio(image)
image_bytes_io.seek(0)
all_image_bytes_io.append(image_bytes_io)
pika_files = [
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
for i, image_bytes_io in enumerate(all_image_bytes_io)
]
pika_request_data = PikaBodyGenerate22C2vGenerate22PikascenesPost(
ingredientsMode=ingredients_mode,
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
aspectRatio=aspect_ratio,
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKASCENES,
method=HttpMethod.POST,
request_model=PikaBodyGenerate22C2vGenerate22PikascenesPost,
response_model=PikaGenerateResponse,
),
request=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_kwargs=kwargs)
class PikAdditionsNode(PikaNodeBase):
"""Pika Pikadditions Node. Add an image into a video."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"video": (IO.VIDEO, {"tooltip": "The video to add an image to."}),
"image": (IO.IMAGE, {"tooltip": "The image to add to the video."}),
"prompt_text": model_field_to_node_input(
IO.STRING,
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
"promptText",
multiline=True,
),
"negative_prompt": model_field_to_node_input(
IO.STRING,
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
"negativePrompt",
multiline=True,
),
"seed": model_field_to_node_input(
IO.INT,
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
"seed",
min=0,
max=0xFFFFFFFF,
control_after_generate=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
DESCRIPTION = "Add any object or image into your video. Upload a video and specify what youd like to add to create a seamlessly integrated result."
def api_call(
self,
video: VideoInput,
image: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
**kwargs,
) -> tuple[VideoFromFile]:
# Convert video to BytesIO
video_bytes_io = io.BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
# Convert image to BytesIO
image_bytes_io = tensor_to_bytesio(image)
image_bytes_io.seek(0)
pika_files = [
("video", ("video.mp4", video_bytes_io, "video/mp4")),
("image", ("image.png", image_bytes_io, "image/png")),
]
# Prepare non-file data
pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKADDITIONS,
method=HttpMethod.POST,
request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
response_model=PikaGenerateResponse,
),
request=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_kwargs=kwargs)
class PikaSwapsNode(PikaNodeBase):
"""Pika Pikaswaps Node."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"video": (IO.VIDEO, {"tooltip": "The video to swap an object in."}),
"image": (
IO.IMAGE,
{
"tooltip": "The image used to replace the masked object in the video."
},
),
"mask": (
IO.MASK,
{"tooltip": "Use the mask to define areas in the video to replace"},
),
"prompt_text": model_field_to_node_input(
IO.STRING,
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
"promptText",
multiline=True,
),
"negative_prompt": model_field_to_node_input(
IO.STRING,
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
"negativePrompt",
multiline=True,
),
"seed": model_field_to_node_input(
IO.INT,
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
"seed",
min=0,
max=0xFFFFFFFF,
control_after_generate=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
DESCRIPTION = "Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates."
RETURN_TYPES = ("VIDEO",)
def api_call(
self,
video: VideoInput,
image: torch.Tensor,
mask: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
**kwargs,
) -> tuple[VideoFromFile]:
# Convert video to BytesIO
video_bytes_io = io.BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
# Convert mask to binary mask with three channels
mask = torch.round(mask)
mask = mask.repeat(1, 3, 1, 1)
# Convert 3-channel binary mask to BytesIO
mask_bytes_io = io.BytesIO()
mask_bytes_io.write(mask.numpy().astype(np.uint8))
mask_bytes_io.seek(0)
# Convert image to BytesIO
image_bytes_io = tensor_to_bytesio(image)
image_bytes_io.seek(0)
pika_files = [
("video", ("video.mp4", video_bytes_io, "video/mp4")),
("image", ("image.png", image_bytes_io, "image/png")),
("modifyRegionMask", ("mask.png", mask_bytes_io, "image/png")),
]
# Prepare non-file data
pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
)
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKASWAPS,
method=HttpMethod.POST,
request_model=PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
response_model=PikaGenerateResponse,
),
request=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_kwargs=kwargs)
class PikaffectsNode(PikaNodeBase):
"""Pika Pikaffects Node."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": (
IO.IMAGE,
{"tooltip": "The reference image to apply the Pikaffect to."},
),
"pikaffect": model_field_to_node_input(
IO.COMBO,
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
"pikaffect",
enum_type=Pikaffect,
default="Cake-ify",
),
"prompt_text": model_field_to_node_input(
IO.STRING,
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
"promptText",
multiline=True,
),
"negative_prompt": model_field_to_node_input(
IO.STRING,
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
"negativePrompt",
multiline=True,
),
"seed": model_field_to_node_input(
IO.INT,
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
"seed",
min=0,
max=0xFFFFFFFF,
control_after_generate=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
DESCRIPTION = "Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear"
def api_call(
self,
image: torch.Tensor,
pikaffect: str,
prompt_text: str,
negative_prompt: str,
seed: int,
**kwargs,
) -> tuple[VideoFromFile]:
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKAFFECTS,
method=HttpMethod.POST,
request_model=PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
response_model=PikaGenerateResponse,
),
request=PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
pikaffect=pikaffect,
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
),
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
content_type="multipart/form-data",
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_kwargs=kwargs)
class PikaStartEndFrameNode2_2(PikaNodeBase):
"""PikaFrames v2.2 Node."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image_start": (IO.IMAGE, {"tooltip": "The first image to combine."}),
"image_end": (IO.IMAGE, {"tooltip": "The last image to combine."}),
**cls.get_base_inputs_types(
PikaBodyGenerate22KeyframeGenerate22PikaframesPost
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
DESCRIPTION = "Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them."
def api_call(
self,
image_start: torch.Tensor,
image_end: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
**kwargs,
) -> tuple[VideoFromFile]:
pika_files = [
(
"keyFrames",
("image_start.png", tensor_to_bytesio(image_start), "image/png"),
),
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
]
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKAFRAMES,
method=HttpMethod.POST,
request_model=PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
response_model=PikaGenerateResponse,
),
request=PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
),
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_kwargs=kwargs)
NODE_CLASS_MAPPINGS = {
"PikaImageToVideoNode2_2": PikaImageToVideoV2_2,
"PikaTextToVideoNode2_2": PikaTextToVideoNodeV2_2,
"PikaScenesV2_2": PikaScenesV2_2,
"Pikadditions": PikAdditionsNode,
"Pikaswaps": PikaSwapsNode,
"Pikaffects": PikaffectsNode,
"PikaStartEndFrameNode2_2": PikaStartEndFrameNode2_2,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"PikaImageToVideoNode2_2": "Pika Image to Video",
"PikaTextToVideoNode2_2": "Pika Text to Video",
"PikaScenesV2_2": "Pika Scenes (Video Image Composition)",
"Pikadditions": "Pikadditions (Video Object Insertion)",
"Pikaswaps": "Pika Swaps (Video Object Replacement)",
"Pikaffects": "Pikaffects (Video Effects)",
"PikaStartEndFrameNode2_2": "Pika Start and End Frame to Video",
}

View File

@ -0,0 +1,492 @@
from inspect import cleandoc
from comfy_api_nodes.apis.pixverse_api import (
PixverseTextVideoRequest,
PixverseImageVideoRequest,
PixverseTransitionVideoRequest,
PixverseImageUploadResponse,
PixverseVideoResponse,
PixverseGenerationStatusResponse,
PixverseAspectRatio,
PixverseQuality,
PixverseDuration,
PixverseMotionMode,
PixverseStatus,
PixverseIO,
pixverse_templates,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
tensor_to_bytesio,
validate_string,
)
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl import VideoFromFile
import torch
import requests
from io import BytesIO
def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
# first, upload image to Pixverse and get image id to use in actual generation call
files = {
"image": tensor_to_bytesio(image)
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/image/upload",
method=HttpMethod.POST,
request_model=EmptyRequest,
response_model=PixverseImageUploadResponse,
),
request=EmptyRequest(),
files=files,
content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
)
response_upload: PixverseImageUploadResponse = operation.execute()
if response_upload.Resp is None:
raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
return response_upload.Resp.img_id
class PixverseTemplateNode:
"""
Select template for PixVerse Video generation.
"""
RETURN_TYPES = (PixverseIO.TEMPLATE,)
RETURN_NAMES = ("pixverse_template",)
FUNCTION = "create_template"
CATEGORY = "api node/video/PixVerse"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"template": (list(pixverse_templates.keys()), ),
}
}
def create_template(self, template: str):
template_id = pixverse_templates.get(template, None)
if template_id is None:
raise Exception(f"Template '{template}' is not recognized.")
# just return the integer
return (template_id,)
class PixverseTextToVideoNode(ComfyNodeABC):
"""
Generates videos synchronously from a text prompt, using the chosen aspect ratio, quality, duration, and motion mode.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/PixVerse"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"aspect_ratio": (
[ratio.value for ratio in PixverseAspectRatio],
),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
"default": PixverseQuality.res_540p,
},
),
"duration_seconds": ([dur.value for dur in PixverseDuration],),
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"control_after_generate": True,
"tooltip": "Seed for video generation.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "An optional text description of undesired elements on an image.",
},
),
"pixverse_template": (
PixverseIO.TEMPLATE,
{
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
}
)
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
prompt: str,
aspect_ratio: str,
quality: str,
duration_seconds: int,
motion_mode: str,
seed,
negative_prompt: str=None,
pixverse_template: int=None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
if quality == PixverseQuality.res_1080p:
motion_mode = PixverseMotionMode.normal
duration_seconds = PixverseDuration.dur_5
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/video/text/generate",
method=HttpMethod.POST,
request_model=PixverseTextVideoRequest,
response_model=PixverseVideoResponse,
),
request=PixverseTextVideoRequest(
prompt=prompt,
aspect_ratio=aspect_ratio,
quality=quality,
duration=duration_seconds,
motion_mode=motion_mode,
negative_prompt=negative_prompt if negative_prompt else None,
template_id=pixverse_template,
seed=seed,
),
auth_kwargs=kwargs,
)
response_api = operation.execute()
if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=PixverseGenerationStatusResponse,
),
completed_statuses=[PixverseStatus.successful],
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=kwargs,
)
response_poll = operation.execute()
vid_response = requests.get(response_poll.Resp.url)
return (VideoFromFile(BytesIO(vid_response.content)),)
class PixverseImageToVideoNode(ComfyNodeABC):
"""
Generates videos synchronously from an input image and text prompt.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/PixVerse"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (
IO.IMAGE,
),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
"default": PixverseQuality.res_540p,
},
),
"duration_seconds": ([dur.value for dur in PixverseDuration],),
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"control_after_generate": True,
"tooltip": "Seed for video generation.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "An optional text description of undesired elements on an image.",
},
),
"pixverse_template": (
PixverseIO.TEMPLATE,
{
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
}
)
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
image: torch.Tensor,
prompt: str,
quality: str,
duration_seconds: int,
motion_mode: str,
seed,
negative_prompt: str=None,
pixverse_template: int=None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
img_id = upload_image_to_pixverse(image, auth_kwargs=kwargs)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
if quality == PixverseQuality.res_1080p:
motion_mode = PixverseMotionMode.normal
duration_seconds = PixverseDuration.dur_5
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/video/img/generate",
method=HttpMethod.POST,
request_model=PixverseImageVideoRequest,
response_model=PixverseVideoResponse,
),
request=PixverseImageVideoRequest(
img_id=img_id,
prompt=prompt,
quality=quality,
duration=duration_seconds,
motion_mode=motion_mode,
negative_prompt=negative_prompt if negative_prompt else None,
template_id=pixverse_template,
seed=seed,
),
auth_kwargs=kwargs,
)
response_api = operation.execute()
if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=PixverseGenerationStatusResponse,
),
completed_statuses=[PixverseStatus.successful],
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=kwargs,
)
response_poll = operation.execute()
vid_response = requests.get(response_poll.Resp.url)
return (VideoFromFile(BytesIO(vid_response.content)),)
class PixverseTransitionVideoNode(ComfyNodeABC):
"""
Generates a transition video synchronously between a supplied first and last frame.
"""
RETURN_TYPES = (IO.VIDEO,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/video/PixVerse"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"first_frame": (
IO.IMAGE,
),
"last_frame": (
IO.IMAGE,
),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the video generation",
},
),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
"default": PixverseQuality.res_540p,
},
),
"duration_seconds": ([dur.value for dur in PixverseDuration],),
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 2147483647,
"control_after_generate": True,
"tooltip": "Seed for video generation.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "An optional text description of undesired elements on an image.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
prompt: str,
quality: str,
duration_seconds: int,
motion_mode: str,
seed,
negative_prompt: str=None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
first_frame_id = upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
last_frame_id = upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
# 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration
if quality == PixverseQuality.res_1080p:
motion_mode = PixverseMotionMode.normal
duration_seconds = PixverseDuration.dur_5
elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/video/transition/generate",
method=HttpMethod.POST,
request_model=PixverseTransitionVideoRequest,
response_model=PixverseVideoResponse,
),
request=PixverseTransitionVideoRequest(
first_frame_img=first_frame_id,
last_frame_img=last_frame_id,
prompt=prompt,
quality=quality,
duration=duration_seconds,
motion_mode=motion_mode,
negative_prompt=negative_prompt if negative_prompt else None,
seed=seed,
),
auth_kwargs=kwargs,
)
response_api = operation.execute()
if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=PixverseGenerationStatusResponse,
),
completed_statuses=[PixverseStatus.successful],
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=kwargs,
)
response_poll = operation.execute()
vid_response = requests.get(response_poll.Resp.url)
return (VideoFromFile(BytesIO(vid_response.content)),)
NODE_CLASS_MAPPINGS = {
"PixverseTextToVideoNode": PixverseTextToVideoNode,
"PixverseImageToVideoNode": PixverseImageToVideoNode,
"PixverseTransitionVideoNode": PixverseTransitionVideoNode,
"PixverseTemplateNode": PixverseTemplateNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"PixverseTextToVideoNode": "PixVerse Text to Video",
"PixverseImageToVideoNode": "PixVerse Image to Video",
"PixverseTransitionVideoNode": "PixVerse Transition Video",
"PixverseTemplateNode": "PixVerse Template",
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,614 @@
from inspect import cleandoc
from comfy.comfy_types.node_typing import IO
from comfy_api_nodes.apis.stability_api import (
StabilityUpscaleConservativeRequest,
StabilityUpscaleCreativeRequest,
StabilityAsyncResponse,
StabilityResultsGetResponse,
StabilityStable3_5Request,
StabilityStableUltraRequest,
StabilityStableUltraResponse,
StabilityAspectRatio,
Stability_SD3_5_Model,
Stability_SD3_5_GenerationMode,
get_stability_style_presets,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
bytesio_to_image_tensor,
tensor_to_bytesio,
validate_string,
)
import torch
import base64
from io import BytesIO
from enum import Enum
class StabilityPollStatus(str, Enum):
finished = "finished"
in_progress = "in_progress"
failed = "failed"
def get_async_dummy_status(x: StabilityResultsGetResponse):
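# The Stability results endpoint does not return an explicit status field, so synthesize one
# for PollingOperation based on which response fields are populated.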
if x.name is not None or x.errors is not None:
return StabilityPollStatus.failed
elif x.finish_reason is not None:
return StabilityPollStatus.finished
return StabilityPollStatus.in_progress
class StabilityStableImageUltraNode:
"""
Generates images synchronously based on prompt and aspect ratio.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Stability AI"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
"What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
"elements, colors, and subjects will lead to better results. " +
"To control the weight of a given word use the format `(word:weight)`," +
"where `word` is the word you'd like to control the weight of and `weight`" +
"is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`" +
"would convey a sky that was blue and green, but more green than blue."
},
),
"aspect_ratio": ([x.value for x in StabilityAspectRatio],
{
"default": StabilityAspectRatio.ratio_1_1,
"tooltip": "Aspect ratio of generated image.",
},
),
"style_preset": (get_stability_style_presets(),
{
"tooltip": "Optional desired style of generated image.",
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 4294967294,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"optional": {
"image": (IO.IMAGE,),
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "A blurb of text describing what you do not wish to see in the output image. This is an advanced feature."
},
),
"image_denoise": (
IO.FLOAT,
{
"default": 0.5,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
**kwargs):
validate_string(prompt, strip_whitespace=False)
# prepare image binary if image present
image_binary = None
if image is not None:
image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
else:
image_denoise = None
if not negative_prompt:
negative_prompt = None
if style_preset == "None":
style_preset = None
files = {
"image": image_binary
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/generate/ultra",
method=HttpMethod.POST,
request_model=StabilityStableUltraRequest,
response_model=StabilityStableUltraResponse,
),
request=StabilityStableUltraRequest(
prompt=prompt,
negative_prompt=negative_prompt,
aspect_ratio=aspect_ratio,
seed=seed,
strength=image_denoise,
style_preset=style_preset,
),
files=files,
content_type="multipart/form-data",
auth_kwargs=kwargs,
)
response_api = operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return (returned_image,)
class StabilityStableImageSD_3_5Node:
"""
Generates images synchronously based on prompt, model, and aspect ratio.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Stability AI"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
},
),
"model": ([x.value for x in Stability_SD3_5_Model],),
"aspect_ratio": ([x.value for x in StabilityAspectRatio],
{
"default": StabilityAspectRatio.ratio_1_1,
"tooltip": "Aspect ratio of generated image.",
},
),
"style_preset": (get_stability_style_presets(),
{
"tooltip": "Optional desired style of generated image.",
},
),
"cfg_scale": (
IO.FLOAT,
{
"default": 4.0,
"min": 1.0,
"max": 10.0,
"step": 0.1,
"tooltip": "How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 4294967294,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"optional": {
"image": (IO.IMAGE,),
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
},
),
"image_denoise": (
IO.FLOAT,
{
"default": 0.5,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
**kwargs):
validate_string(prompt, strip_whitespace=False)
# prepare image binary if image present
image_binary = None
mode = Stability_SD3_5_GenerationMode.text_to_image
if image is not None:
image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
mode = Stability_SD3_5_GenerationMode.image_to_image
aspect_ratio = None
else:
image_denoise = None
if not negative_prompt:
negative_prompt = None
if style_preset == "None":
style_preset = None
files = {
"image": image_binary
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/generate/sd3",
method=HttpMethod.POST,
request_model=StabilityStable3_5Request,
response_model=StabilityStableUltraResponse,
),
request=StabilityStable3_5Request(
prompt=prompt,
negative_prompt=negative_prompt,
aspect_ratio=aspect_ratio,
seed=seed,
strength=image_denoise,
style_preset=style_preset,
cfg_scale=cfg_scale,
model=model,
mode=mode,
),
files=files,
content_type="multipart/form-data",
auth_kwargs=kwargs,
)
response_api = operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return (returned_image,)
class StabilityUpscaleConservativeNode:
"""
Upscale image with minimal alterations to 4K resolution.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Stability AI"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
},
),
"creativity": (
IO.FLOAT,
{
"default": 0.35,
"min": 0.2,
"max": 0.5,
"step": 0.01,
"tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.",
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 4294967294,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
**kwargs):
validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
if not negative_prompt:
negative_prompt = None
files = {
"image": image_binary
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/conservative",
method=HttpMethod.POST,
request_model=StabilityUpscaleConservativeRequest,
response_model=StabilityStableUltraResponse,
),
request=StabilityUpscaleConservativeRequest(
prompt=prompt,
negative_prompt=negative_prompt,
creativity=round(creativity,2),
seed=seed,
),
files=files,
content_type="multipart/form-data",
auth_kwargs=kwargs,
)
response_api = operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return (returned_image,)
class StabilityUpscaleCreativeNode:
"""
Upscale image to 4K resolution with creative reimagining of details; best suited for heavily degraded inputs.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Stability AI"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
},
),
"creativity": (
IO.FLOAT,
{
"default": 0.3,
"min": 0.1,
"max": 0.5,
"step": 0.01,
"tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.",
},
),
"style_preset": (get_stability_style_presets(),
{
"tooltip": "Optional desired style of generated image.",
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 4294967294,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"default": "",
"forceInput": True,
"tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
**kwargs):
validate_string(prompt, strip_whitespace=False)
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
if not negative_prompt:
negative_prompt = None
if style_preset == "None":
style_preset = None
files = {
"image": image_binary
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/creative",
method=HttpMethod.POST,
request_model=StabilityUpscaleCreativeRequest,
response_model=StabilityAsyncResponse,
),
request=StabilityUpscaleCreativeRequest(
prompt=prompt,
negative_prompt=negative_prompt,
creativity=round(creativity,2),
style_preset=style_preset,
seed=seed,
),
files=files,
content_type="multipart/form-data",
auth_kwargs=kwargs,
)
response_api = operation.execute()
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/stability/v2beta/results/{response_api.id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=StabilityResultsGetResponse,
),
poll_interval=3,
completed_statuses=[StabilityPollStatus.finished],
failed_statuses=[StabilityPollStatus.failed],
status_extractor=lambda x: get_async_dummy_status(x),
auth_kwargs=kwargs,
)
response_poll: StabilityResultsGetResponse = operation.execute()
if response_poll.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
image_data = base64.b64decode(response_poll.result)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return (returned_image,)
class StabilityUpscaleFastNode:
"""
Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images.
"""
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/Stability AI"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (IO.IMAGE,),
},
"optional": {
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(self, image: torch.Tensor,
**kwargs):
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
files = {
"image": image_binary
}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/fast",
method=HttpMethod.POST,
request_model=EmptyRequest,
response_model=StabilityStableUltraResponse,
),
request=EmptyRequest(),
files=files,
content_type="multipart/form-data",
auth_kwargs=kwargs,
)
response_api = operation.execute()
if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
image_data = base64.b64decode(response_api.image)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
return (returned_image,)
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"StabilityStableImageUltraNode": StabilityStableImageUltraNode,
"StabilityStableImageSD_3_5Node": StabilityStableImageSD_3_5Node,
"StabilityUpscaleConservativeNode": StabilityUpscaleConservativeNode,
"StabilityUpscaleCreativeNode": StabilityUpscaleCreativeNode,
"StabilityUpscaleFastNode": StabilityUpscaleFastNode,
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"StabilityStableImageUltraNode": "Stability AI Stable Image Ultra",
"StabilityStableImageSD_3_5Node": "Stability AI Stable Diffusion 3.5 Image",
"StabilityUpscaleConservativeNode": "Stability AI Upscale Conservative",
"StabilityUpscaleCreativeNode": "Stability AI Upscale Creative",
"StabilityUpscaleFastNode": "Stability AI Upscale Fast",
}

View File

@ -0,0 +1,284 @@
import io
import logging
import base64
import requests
import torch
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis import (
Veo2GenVidRequest,
Veo2GenVidResponse,
Veo2GenVidPollRequest,
Veo2GenVidPollResponse
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
)
from comfy_api_nodes.apinode_utils import (
downscale_image_tensor,
tensor_to_base64_string
)
def convert_image_to_base64(image: torch.Tensor):
if image is None:
return None
scaled_image = downscale_image_tensor(image, total_pixels=2048*2048)
return tensor_to_base64_string(scaled_image)
class VeoVideoGenerationNode(ComfyNodeABC):
"""
Generates videos from text prompts using Google's Veo API.
This node can create videos from text descriptions and optional image inputs,
with control over parameters like aspect ratio, duration, and more.
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Text description of the video",
},
),
"aspect_ratio": (
IO.COMBO,
{
"options": ["16:9", "9:16"],
"default": "16:9",
"tooltip": "Aspect ratio of the output video",
},
),
},
"optional": {
"negative_prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Negative text prompt to guide what to avoid in the video",
},
),
"duration_seconds": (
IO.INT,
{
"default": 5,
"min": 5,
"max": 8,
"step": 1,
"display": "number",
"tooltip": "Duration of the output video in seconds",
},
),
"enhance_prompt": (
IO.BOOLEAN,
{
"default": True,
"tooltip": "Whether to enhance the prompt with AI assistance",
}
),
"person_generation": (
IO.COMBO,
{
"options": ["ALLOW", "BLOCK"],
"default": "ALLOW",
"tooltip": "Whether to allow generating people in the video",
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFF,
"step": 1,
"display": "number",
"control_after_generate": True,
"tooltip": "Seed for video generation (0 for random)",
},
),
"image": (IO.IMAGE, {
"default": None,
"tooltip": "Optional reference image to guide video generation",
}),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
RETURN_TYPES = (IO.VIDEO,)
FUNCTION = "generate_video"
CATEGORY = "api node/video/Veo"
DESCRIPTION = "Generates videos from text prompts using Google's Veo API"
API_NODE = True
def generate_video(
self,
prompt,
aspect_ratio="16:9",
negative_prompt="",
duration_seconds=5,
enhance_prompt=True,
person_generation="ALLOW",
seed=0,
image=None,
**kwargs,
):
# Prepare the instances for the request
instances = []
instance = {
"prompt": prompt
}
# Add image if provided
if image is not None:
image_base64 = convert_image_to_base64(image)
if image_base64:
instance["image"] = {
"bytesBase64Encoded": image_base64,
"mimeType": "image/png"
}
instances.append(instance)
# Create parameters dictionary
parameters = {
"aspectRatio": aspect_ratio,
"personGeneration": person_generation,
"durationSeconds": duration_seconds,
"enhancePrompt": enhance_prompt,
}
# Add optional parameters if provided
if negative_prompt:
parameters["negativePrompt"] = negative_prompt
if seed > 0:
parameters["seed"] = seed
# Initial request to start video generation
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/veo/generate",
method=HttpMethod.POST,
request_model=Veo2GenVidRequest,
response_model=Veo2GenVidResponse
),
request=Veo2GenVidRequest(
instances=instances,
parameters=parameters
),
auth_kwargs=kwargs,
)
initial_response = initial_operation.execute()
operation_name = initial_response.name
logging.info(f"Veo generation started with operation name: {operation_name}")
# Define status extractor function
def status_extractor(response):
# Only return "completed" if the operation is done, regardless of success or failure
# We'll check for errors after polling completes
return "completed" if response.done else "pending"
# Define progress extractor function
def progress_extractor(response):
# Could be enhanced if the API provides progress information
return None
# Define the polling operation
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/veo/poll",
method=HttpMethod.POST,
request_model=Veo2GenVidPollRequest,
response_model=Veo2GenVidPollResponse
),
completed_statuses=["completed"],
failed_statuses=[], # No failed statuses, we'll handle errors after polling
status_extractor=status_extractor,
progress_extractor=progress_extractor,
request=Veo2GenVidPollRequest(
operationName=operation_name
),
auth_kwargs=kwargs,
poll_interval=5.0
)
# Execute the polling operation
poll_response = poll_operation.execute()
# Now check for errors in the final response
# Check for error in poll response
if hasattr(poll_response, 'error') and poll_response.error:
error_message = f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})"
logging.error(error_message)
raise Exception(error_message)
# Check for RAI filtered content
if (hasattr(poll_response.response, 'raiMediaFilteredCount') and
poll_response.response.raiMediaFilteredCount > 0):
# Extract reason message if available
if (hasattr(poll_response.response, 'raiMediaFilteredReasons') and
poll_response.response.raiMediaFilteredReasons):
reason = poll_response.response.raiMediaFilteredReasons[0]
error_message = f"Content filtered by Google's Responsible AI practices: {reason} ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
else:
error_message = f"Content filtered by Google's Responsible AI practices ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
logging.error(error_message)
raise Exception(error_message)
# Extract video data
video_data = None
if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0:
video = poll_response.response.videos[0]
# Check if video is provided as base64 or URL
if hasattr(video, 'bytesBase64Encoded') and video.bytesBase64Encoded:
# Decode base64 string to bytes
video_data = base64.b64decode(video.bytesBase64Encoded)
elif hasattr(video, 'gcsUri') and video.gcsUri:
# Download from URL
video_url = video.gcsUri
video_response = requests.get(video_url)
video_data = video_response.content
else:
raise Exception("Video returned but no data or URL was provided")
else:
raise Exception("Video generation completed but no video was returned")
if not video_data:
raise Exception("No video data was returned")
logging.info("Video generation completed successfully")
# Convert video data to BytesIO object
video_io = io.BytesIO(video_data)
# Return VideoFromFile object
return (VideoFromFile(video_io),)
# Register the node
NODE_CLASS_MAPPINGS = {
"VeoVideoGenerationNode": VeoVideoGenerationNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"VeoVideoGenerationNode": "Google Veo2 Video Generation",
}

View File

@ -0,0 +1,10 @@
# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes.
# This is used for development purposes to generate stubs for unreleased API endpoints.
apis:
filter:
root: openapi.yaml
decorators:
filter-in:
property: tags
value: ['API Nodes']
matchStrategy: all

View File

@ -0,0 +1,10 @@
# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes.
apis:
filter:
root: openapi.yaml
decorators:
filter-in:
property: tags
value: ['API Nodes', 'Released']
matchStrategy: all

49
comfy_extras/nodes_ace.py Normal file
View File

@ -0,0 +1,49 @@
import torch
import comfy.model_management
import node_helpers
class TextEncodeAceStepAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"tags": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"lyrics_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "conditioning"
def encode(self, clip, tags, lyrics, lyrics_strength):
tokens = clip.tokenize(tags, lyrics=lyrics)
conditioning = clip.encode_from_tokens_scheduled(tokens)
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
return (conditioning, )
class EmptyAceStepLatentAudio:
def __init__(self):
self.device = comfy.model_management.intermediate_device()
@classmethod
def INPUT_TYPES(s):
return {"required": {"seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/audio"
def generate(self, seconds, batch_size):
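# Assumption: ACE-Step audio latents run at 44.1 kHz with a combined time-compression factor of 512 * 8,
# so one second of audio corresponds to roughly 44100 / 4096 ≈ 10.8 latent frames.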
length = int(seconds * 44100 / 512 / 8)
latent = torch.zeros([batch_size, 8, 16, length], device=self.device)
return ({"samples": latent, "type": "audio"}, )
NODE_CLASS_MAPPINGS = {
"TextEncodeAceStepAudio": TextEncodeAceStepAudio,
"EmptyAceStepLatentAudio": EmptyAceStepLatentAudio,
}

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import av
import torchaudio
import torch
import comfy.model_management
@ -7,7 +8,6 @@ import folder_paths
import os
import io
import json
import struct
import random
import hashlib
import node_helpers
@ -90,60 +90,118 @@ class VAEDecodeAudio:
return ({"waveform": audio, "sample_rate": 44100}, )
def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):

    filename_prefix += self.prefix_append
    full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
    results: list[FileLocator] = []

    # Prepare metadata dictionary
    metadata = {}
    if not args.disable_metadata:
        if prompt is not None:
            metadata["prompt"] = json.dumps(prompt)
        if extra_pnginfo is not None:
            for x in extra_pnginfo:
                metadata[x] = json.dumps(extra_pnginfo[x])

    # Opus supported sample rates
    OPUS_RATES = [8000, 12000, 16000, 24000, 48000]

    for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
        filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
        file = f"{filename_with_batch_num}_{counter:05}_.{format}"
        output_path = os.path.join(full_output_folder, file)

        # Use original sample rate initially
        sample_rate = audio["sample_rate"]

        # Handle Opus sample rate requirements
        if format == "opus":
            if sample_rate > 48000:
                sample_rate = 48000
            elif sample_rate not in OPUS_RATES:
                # Find the next highest supported rate
                for rate in sorted(OPUS_RATES):
                    if rate > sample_rate:
                        sample_rate = rate
                        break
                if sample_rate not in OPUS_RATES: # Fallback if still not supported
                    sample_rate = 48000

            # Resample if necessary
            if sample_rate != audio["sample_rate"]:
                waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)

        # Create in-memory WAV buffer
        wav_buffer = io.BytesIO()
        torchaudio.save(wav_buffer, waveform, sample_rate, format="WAV")
        wav_buffer.seek(0)  # Rewind for reading

        # Use PyAV to convert and add metadata
        input_container = av.open(wav_buffer)

        # Create output with specified format
        output_buffer = io.BytesIO()
        output_container = av.open(output_buffer, mode='w', format=format)

        # Set metadata on the container
        for key, value in metadata.items():
            output_container.metadata[key] = value

        # Set up the output stream with appropriate properties
        input_container.streams.audio[0]
        if format == "opus":
            out_stream = output_container.add_stream("libopus", rate=sample_rate)
            if quality == "64k":
                out_stream.bit_rate = 64000
            elif quality == "96k":
                out_stream.bit_rate = 96000
            elif quality == "128k":
                out_stream.bit_rate = 128000
            elif quality == "192k":
                out_stream.bit_rate = 192000
            elif quality == "320k":
                out_stream.bit_rate = 320000
        elif format == "mp3":
            out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
            if quality == "V0":
                #TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
                out_stream.codec_context.qscale = 1
            elif quality == "128k":
                out_stream.bit_rate = 128000
            elif quality == "320k":
                out_stream.bit_rate = 320000
        else: #format == "flac":
            out_stream = output_container.add_stream("flac", rate=sample_rate)

        # Copy frames from input to output
        for frame in input_container.decode(audio=0):
            frame.pts = None  # Let PyAV handle timestamps
            output_container.mux(out_stream.encode(frame))

        # Flush encoder
        output_container.mux(out_stream.encode(None))

        # Close containers
        output_container.close()
        input_container.close()

        # Write the output to file
        output_buffer.seek(0)
        with open(output_path, 'wb') as f:
            f.write(output_buffer.getbuffer())

        results.append({
            "filename": file,
            "subfolder": subfolder,
            "type": self.type
        })
        counter += 1

    return { "ui": { "audio": results } }
class SaveAudio:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@ -153,50 +211,70 @@ class SaveAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"})},
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_audio"
FUNCTION = "save_flac"
OUTPUT_NODE = True
CATEGORY = "audio"
def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
results: list[FileLocator] = []
def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)
metadata = {}
if not args.disable_metadata:
if prompt is not None:
metadata["prompt"] = json.dumps(prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
class SaveAudioMP3:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.flac"
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
"quality": (["V0", "128k", "320k"], {"default": "V0"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
buff = io.BytesIO()
torchaudio.save(buff, waveform, audio["sample_rate"], format="FLAC")
RETURN_TYPES = ()
FUNCTION = "save_mp3"
buff = insert_or_replace_vorbis_comment(buff, metadata)
OUTPUT_NODE = True
with open(os.path.join(full_output_folder, file), 'wb') as f:
f.write(buff.getbuffer())
CATEGORY = "audio"
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
return { "ui": { "audio": results } }
class SaveAudioOpus:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
"quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_opus"
OUTPUT_NODE = True
CATEGORY = "audio"
def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
class PreviewAudio(SaveAudio):
def __init__(self):
@ -248,7 +326,20 @@ NODE_CLASS_MAPPINGS = {
"VAEEncodeAudio": VAEEncodeAudio,
"VAEDecodeAudio": VAEDecodeAudio,
"SaveAudio": SaveAudio,
"SaveAudioMP3": SaveAudioMP3,
"SaveAudioOpus": SaveAudioOpus,
"LoadAudio": LoadAudio,
"PreviewAudio": PreviewAudio,
"ConditioningStableAudio": ConditioningStableAudio,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"EmptyLatentAudio": "Empty Latent Audio",
"VAEEncodeAudio": "VAE Encode Audio",
"VAEDecodeAudio": "VAE Decode Audio",
"PreviewAudio": "Preview Audio",
"LoadAudio": "Load Audio",
"SaveAudio": "Save Audio (FLAC)",
"SaveAudioMP3": "Save Audio (MP3)",
"SaveAudioOpus": "Save Audio (Opus)",
}

View File

@ -20,6 +20,29 @@ class CLIPTextEncodeControlnet:
c.append(n)
return (c, )
class T5TokenizerOptions:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"clip": ("CLIP", ),
"min_padding": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
"min_length": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
}
}
RETURN_TYPES = ("CLIP",)
FUNCTION = "set_options"
def set_options(self, clip, min_padding, min_length):
clip = clip.clone()
for t5_type in ["t5xxl", "pile_t5xl", "t5base", "mt5xl", "umt5xxl"]:
clip.set_tokenizer_option("{}_min_padding".format(t5_type), min_padding)
clip.set_tokenizer_option("{}_min_length".format(t5_type), min_length)
return (clip, )
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeControlnet": CLIPTextEncodeControlnet
"CLIPTextEncodeControlnet": CLIPTextEncodeControlnet,
"T5TokenizerOptions": T5TokenizerOptions,
}

View File

@ -1,3 +1,4 @@
import math
import comfy.samplers
import comfy.sample
from comfy.k_diffusion import sampling as k_diffusion_sampling
@ -249,6 +250,55 @@ class SetFirstSigma:
sigmas[0] = sigma
return (sigmas, )
class ExtendIntermediateSigmas:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"sigmas": ("SIGMAS", ),
"steps": ("INT", {"default": 2, "min": 1, "max": 100}),
"start_at_sigma": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 20000.0, "step": 0.01, "round": False}),
"end_at_sigma": ("FLOAT", {"default": 12.0, "min": 0.0, "max": 20000.0, "step": 0.01, "round": False}),
"spacing": (['linear', 'cosine', 'sine'],),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/sigmas"
FUNCTION = "extend"
def extend(self, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str):
if start_at_sigma < 0:
start_at_sigma = float("inf")
interpolator = {
'linear': lambda x: x,
'cosine': lambda x: torch.sin(x*math.pi/2),
'sine': lambda x: 1 - torch.cos(x*math.pi/2)
}[spacing]
# linear space for our interpolation function
x = torch.linspace(0, 1, steps + 1, device=sigmas.device)[1:-1]
computed_spacing = interpolator(x)
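# e.g. steps=2 with 'linear' spacing gives torch.linspace(0, 1, 3)[1:-1] == [0.5],
# i.e. one extra sigma at the midpoint of each qualifying pair below.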
extended_sigmas = []
for i in range(len(sigmas) - 1):
sigma_current = sigmas[i]
sigma_next = sigmas[i+1]
extended_sigmas.append(sigma_current)
if end_at_sigma <= sigma_current <= start_at_sigma:
interpolated_steps = computed_spacing * (sigma_next - sigma_current) + sigma_current
extended_sigmas.extend(interpolated_steps.tolist())
# Add the last sigma value
if len(sigmas) > 0:
extended_sigmas.append(sigmas[-1])
extended_sigmas = torch.FloatTensor(extended_sigmas)
return (extended_sigmas,)
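
A small, self-contained illustration of the interpolation above: with steps=2 the node places steps - 1 = 1 extra sigma between each adjacent pair that falls inside the [end_at_sigma, start_at_sigma] window. The sigma values below are illustrative.

import math
import torch

sigmas = torch.tensor([14.6, 8.0, 1.0, 0.0])
steps = 2
x = torch.linspace(0, 1, steps + 1)[1:-1]   # interior points only -> tensor([0.5000])
cosine = torch.sin(x * math.pi / 2)         # -> tensor([0.7071])
between = cosine * (sigmas[1] - sigmas[0]) + sigmas[0]
# 14.6 + 0.7071 * (8.0 - 14.6) is roughly 9.93, inserted between 14.6 and 8.0
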
class KSamplerSelect:
@classmethod
def INPUT_TYPES(s):
@ -735,6 +785,7 @@ NODE_CLASS_MAPPINGS = {
"SplitSigmasDenoise": SplitSigmasDenoise,
"FlipSigmas": FlipSigmas,
"SetFirstSigma": SetFirstSigma,
"ExtendIntermediateSigmas": ExtendIntermediateSigmas,
"CFGGuider": CFGGuider,
"DualCFGGuider": DualCFGGuider,

View File

@ -10,6 +10,9 @@ from PIL.PngImagePlugin import PngInfo
import numpy as np
import json
import os
import re
from io import BytesIO
from inspect import cleandoc
from comfy.comfy_types import FileLocator
@ -190,10 +193,109 @@ class SaveAnimatedPNG:
return { "ui": { "images": results, "animated": (True,)} }
class SVG:
"""
Stores SVG representations via a list of BytesIO objects.
"""
def __init__(self, data: list[BytesIO]):
self.data = data
def combine(self, other: 'SVG') -> 'SVG':
return SVG(self.data + other.data)
@staticmethod
def combine_all(svgs: list['SVG']) -> 'SVG':
all_svgs_list: list[BytesIO] = []
for svg_item in svgs:
all_svgs_list.extend(svg_item.data)
return SVG(all_svgs_list)
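
A hedged usage sketch of the SVG container above, combining two single-item batches; the markup strings are illustrative.

from io import BytesIO

svg_a = SVG([BytesIO(b'<svg xmlns="http://www.w3.org/2000/svg"><rect width="10" height="10"/></svg>')])
svg_b = SVG([BytesIO(b'<svg xmlns="http://www.w3.org/2000/svg"><circle r="5"/></svg>')])
merged = SVG.combine_all([svg_a, svg_b])  # merged.data keeps both BytesIO buffers, in order
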
class SaveSVGNode:
"""
Save SVG files on disk.
"""
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
RETURN_TYPES = ()
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "save_svg"
CATEGORY = "image/save" # Changed
OUTPUT_NODE = True
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"svg": ("SVG",), # Changed
"filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
},
"hidden": {
"prompt": "PROMPT",
"extra_pnginfo": "EXTRA_PNGINFO"
}
}
def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
results = list()
# Prepare metadata JSON
metadata_dict = {}
if prompt is not None:
metadata_dict["prompt"] = prompt
if extra_pnginfo is not None:
metadata_dict.update(extra_pnginfo)
# Convert metadata to JSON string
metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
for batch_number, svg_bytes in enumerate(svg.data):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.svg"
# Read SVG content
svg_bytes.seek(0)
svg_content = svg_bytes.read().decode('utf-8')
# Inject metadata if available
if metadata_json:
# Create metadata element with CDATA section
metadata_element = f""" <metadata>
<![CDATA[
{metadata_json}
]]>
</metadata>
"""
# Insert metadata after opening svg tag using regex with a replacement function
def replacement(match):
# match.group(1) contains the captured <svg> tag
return match.group(1) + '\n' + metadata_element
# Apply the substitution
svg_content = re.sub(r'(<svg[^>]*>)', replacement, svg_content, flags=re.UNICODE)
# Write the modified SVG to file
with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
svg_file.write(svg_content.encode('utf-8'))
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
return { "ui": { "images": results } }
NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch,
"ImageFromBatch": ImageFromBatch,
"SaveAnimatedWEBP": SaveAnimatedWEBP,
"SaveAnimatedPNG": SaveAnimatedPNG,
"SaveSVGNode": SaveSVGNode,
}

View File

@ -2,6 +2,10 @@ import nodes
import folder_paths
import os
from comfy.comfy_types import IO
from comfy_api.input_impl import VideoFromFile
def normalize_path(path):
return path.replace('\\', '/')
@ -21,8 +25,8 @@ class Load3D():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info")
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info", "recording_video")
FUNCTION = "process"
EXPERIMENTAL = True
@ -41,7 +45,14 @@ class Load3D():
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info']
video = None
if image['recording'] != "":
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
video = VideoFromFile(recording_video_path)
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'], video
class Load3DAnimation():
@classmethod
@ -59,8 +70,8 @@ class Load3DAnimation():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info")
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video")
FUNCTION = "process"
EXPERIMENTAL = True
@ -77,7 +88,14 @@ class Load3DAnimation():
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
return output_image, output_mask, model_file, normal_image, image['camera_info']
video = None
if image['recording'] != "":
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
video = VideoFromFile(recording_video_path)
return output_image, output_mask, model_file, normal_image, image['camera_info'], video
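
A minimal hedged sketch of the new recording_video wiring shared by Load3D and Load3DAnimation above; the annotated filename is hypothetical, while VideoFromFile and get_annotated_filepath are the imports already used in this file.

from comfy_api.input_impl import VideoFromFile
import folder_paths

recording = "viewport_capture.mp4 [temp]"  # hypothetical annotated filename from the 3D widget
video = None
if recording != "":
    video = VideoFromFile(folder_paths.get_annotated_filepath(recording))
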
class Preview3D():
@classmethod

Some files were not shown because too many files have changed in this diff.