mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-08 15:17:14 +00:00
Merge branch 'master' into worksplit-multigpu
This commit is contained in:
commit
9726eac475
@ -63,7 +63,12 @@ except:
|
||||
print("checking out master branch") # noqa: T201
|
||||
branch = repo.lookup_branch('master')
|
||||
if branch is None:
|
||||
ref = repo.lookup_reference('refs/remotes/origin/master')
|
||||
try:
|
||||
ref = repo.lookup_reference('refs/remotes/origin/master')
|
||||
except:
|
||||
print("pulling.") # noqa: T201
|
||||
pull(repo)
|
||||
ref = repo.lookup_reference('refs/remotes/origin/master')
|
||||
repo.checkout(ref)
|
||||
branch = repo.lookup_branch('master')
|
||||
if branch is None:
|
||||
|
12
.github/workflows/stable-release.yml
vendored
12
.github/workflows/stable-release.yml
vendored
@ -12,7 +12,7 @@ on:
|
||||
description: 'CUDA version'
|
||||
required: true
|
||||
type: string
|
||||
default: "126"
|
||||
default: "128"
|
||||
python_minor:
|
||||
description: 'Python minor version'
|
||||
required: true
|
||||
@ -22,7 +22,7 @@ on:
|
||||
description: 'Python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "9"
|
||||
default: "10"
|
||||
|
||||
|
||||
jobs:
|
||||
@ -36,7 +36,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.git_tag }}
|
||||
fetch-depth: 0
|
||||
fetch-depth: 150
|
||||
persist-credentials: false
|
||||
- uses: actions/cache/restore@v4
|
||||
id: cache
|
||||
@ -70,7 +70,7 @@ jobs:
|
||||
cd ..
|
||||
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
|
||||
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
|
||||
|
||||
mkdir ComfyUI_windows_portable
|
||||
mv python_embeded ComfyUI_windows_portable
|
||||
@ -85,12 +85,14 @@ jobs:
|
||||
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
|
||||
|
||||
cd ComfyUI_windows_portable
|
||||
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
|
||||
|
||||
python_embeded/python.exe -s ./update/update.py ComfyUI/
|
||||
|
||||
ls
|
||||
|
||||
- name: Upload binaries to release
|
||||
|
2
.github/workflows/test-launch.yml
vendored
2
.github/workflows/test-launch.yml
vendored
@ -17,7 +17,7 @@ jobs:
|
||||
path: "ComfyUI"
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.9'
|
||||
python-version: '3.10'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
56
.github/workflows/update-api-stubs.yml
vendored
Normal file
56
.github/workflows/update-api-stubs.yml
vendored
Normal file
@ -0,0 +1,56 @@
|
||||
name: Generate Pydantic Stubs from api.comfy.org
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 0 * * 1'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
generate-models:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install 'datamodel-code-generator[http]'
|
||||
npm install @redocly/cli
|
||||
|
||||
- name: Download OpenAPI spec
|
||||
run: |
|
||||
curl -o openapi.yaml https://api.comfy.org/openapi
|
||||
|
||||
- name: Filter OpenAPI spec with Redocly
|
||||
run: |
|
||||
npx @redocly/cli bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components
|
||||
|
||||
- name: Generate API models
|
||||
run: |
|
||||
datamodel-codegen --use-subclass-enum --input filtered-openapi.yaml --output comfy_api_nodes/apis --output-model-type pydantic_v2.BaseModel
|
||||
|
||||
- name: Check for changes
|
||||
id: git-check
|
||||
run: |
|
||||
git diff --exit-code comfy_api_nodes/apis || echo "changes=true" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Create Pull Request
|
||||
if: steps.git-check.outputs.changes == 'true'
|
||||
uses: peter-evans/create-pull-request@v5
|
||||
with:
|
||||
commit-message: 'chore: update API models from OpenAPI spec'
|
||||
title: 'Update API models from api.comfy.org'
|
||||
body: |
|
||||
This PR updates the API models based on the latest api.comfy.org OpenAPI specification.
|
||||
|
||||
Generated automatically by the a Github workflow.
|
||||
branch: update-api-stubs
|
||||
delete-branch: true
|
||||
base: master
|
@ -17,7 +17,7 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "126"
|
||||
default: "128"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
@ -29,7 +29,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "9"
|
||||
default: "10"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
@ -56,7 +56,7 @@ jobs:
|
||||
cd ..
|
||||
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
|
||||
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
|
||||
|
||||
mkdir ComfyUI_windows_portable_nightly_pytorch
|
||||
mv python_embeded ComfyUI_windows_portable_nightly_pytorch
|
||||
|
12
.github/workflows/windows_release_package.yml
vendored
12
.github/workflows/windows_release_package.yml
vendored
@ -7,7 +7,7 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "126"
|
||||
default: "128"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
@ -19,7 +19,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "9"
|
||||
default: "10"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@ -50,7 +50,7 @@ jobs:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
fetch-depth: 150
|
||||
persist-credentials: false
|
||||
- shell: bash
|
||||
run: |
|
||||
@ -67,7 +67,7 @@ jobs:
|
||||
cd ..
|
||||
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
|
||||
cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/
|
||||
|
||||
mkdir ComfyUI_windows_portable
|
||||
mv python_embeded ComfyUI_windows_portable
|
||||
@ -82,12 +82,14 @@ jobs:
|
||||
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z
|
||||
|
||||
cd ComfyUI_windows_portable
|
||||
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
|
||||
|
||||
python_embeded/python.exe -s ./update/update.py ComfyUI/
|
||||
|
||||
ls
|
||||
|
||||
- name: Upload binaries to release
|
||||
|
3
.gitignore
vendored
3
.gitignore
vendored
@ -21,3 +21,6 @@ venv/
|
||||
*.log
|
||||
web_custom_versions/
|
||||
.DS_Store
|
||||
openapi.yaml
|
||||
filtered-openapi.yaml
|
||||
uv.lock
|
||||
|
28
README.md
28
README.md
@ -49,7 +49,6 @@ Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon,
|
||||
## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
|
||||
|
||||
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
- Image Models
|
||||
@ -70,9 +69,11 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
|
||||
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
|
||||
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
|
||||
- Audio Models
|
||||
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- 3D Models
|
||||
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
|
||||
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- Asynchronous Queue system
|
||||
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
||||
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
|
||||
@ -99,6 +100,23 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
|
||||
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||
|
||||
## Release Process
|
||||
|
||||
ComfyUI follows a weekly release cycle every Friday, with three interconnected repositories:
|
||||
|
||||
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
||||
- Releases a new stable version (e.g., v0.7.0)
|
||||
- Serves as the foundation for the desktop release
|
||||
|
||||
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
|
||||
- Builds a new release using the latest stable core version
|
||||
- Version numbers match the core release (e.g., Desktop v1.7.0 uses Core v1.7.0)
|
||||
|
||||
3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**
|
||||
- Weekly frontend updates are merged into the core repository
|
||||
- Features are frozen for the upcoming core release
|
||||
- Development continues for the next release cycle
|
||||
|
||||
## Shortcuts
|
||||
|
||||
| Keybind | Explanation |
|
||||
@ -149,8 +167,6 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you
|
||||
|
||||
If you have trouble extracting it, right click the file -> properties -> unblock
|
||||
|
||||
If you have a 50 series Blackwell card like a 5090 or 5080 see [this discussion thread](https://github.com/comfyanonymous/ComfyUI/discussions/6643)
|
||||
|
||||
#### How do I share models between another UI and ComfyUI?
|
||||
|
||||
See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor.
|
||||
@ -216,9 +232,9 @@ Additional discussion and help can be found [here](https://github.com/comfyanony
|
||||
|
||||
Nvidia users should install stable pytorch using this command:
|
||||
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126```
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128```
|
||||
|
||||
This is the command to install pytorch nightly instead which supports the new blackwell 50xx series GPUs and might have performance improvements.
|
||||
This is the command to install pytorch nightly instead which might have performance improvements.
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
|
||||
|
||||
|
@ -93,16 +93,20 @@ class CustomNodeManager:
|
||||
|
||||
def add_routes(self, routes, webapp, loadedModules):
|
||||
|
||||
example_workflow_folder_names = ["example_workflows", "example", "examples", "workflow", "workflows"]
|
||||
|
||||
@routes.get("/workflow_templates")
|
||||
async def get_workflow_templates(request):
|
||||
"""Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted."""
|
||||
files = [
|
||||
file
|
||||
for folder in folder_paths.get_folder_paths("custom_nodes")
|
||||
for file in glob.glob(
|
||||
os.path.join(folder, "*/example_workflows/*.json")
|
||||
)
|
||||
]
|
||||
|
||||
files = []
|
||||
|
||||
for folder in folder_paths.get_folder_paths("custom_nodes"):
|
||||
for folder_name in example_workflow_folder_names:
|
||||
pattern = os.path.join(folder, f"*/{folder_name}/*.json")
|
||||
matched_files = glob.glob(pattern)
|
||||
files.extend(matched_files)
|
||||
|
||||
workflow_templates_dict = (
|
||||
{}
|
||||
) # custom_nodes folder name -> example workflow names
|
||||
@ -118,15 +122,22 @@ class CustomNodeManager:
|
||||
|
||||
# Serve workflow templates from custom nodes.
|
||||
for module_name, module_dir in loadedModules:
|
||||
workflows_dir = os.path.join(module_dir, "example_workflows")
|
||||
if os.path.exists(workflows_dir):
|
||||
webapp.add_routes(
|
||||
[
|
||||
web.static(
|
||||
"/api/workflow_templates/" + module_name, workflows_dir
|
||||
)
|
||||
]
|
||||
)
|
||||
for folder_name in example_workflow_folder_names:
|
||||
workflows_dir = os.path.join(module_dir, folder_name)
|
||||
|
||||
if os.path.exists(workflows_dir):
|
||||
if folder_name != "example_workflows":
|
||||
logging.debug(
|
||||
"Found example workflow folder '%s' for custom node '%s', consider renaming it to 'example_workflows'",
|
||||
folder_name, module_name)
|
||||
|
||||
webapp.add_routes(
|
||||
[
|
||||
web.static(
|
||||
"/api/workflow_templates/" + module_name, workflows_dir
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
@routes.get("/i18n")
|
||||
async def get_i18n(request):
|
||||
|
@ -197,6 +197,112 @@ class UserManager():
|
||||
|
||||
return web.json_response(results)
|
||||
|
||||
@routes.get("/v2/userdata")
|
||||
async def list_userdata_v2(request):
|
||||
"""
|
||||
List files and directories in a user's data directory.
|
||||
|
||||
This endpoint provides a structured listing of contents within a specified
|
||||
subdirectory of the user's data storage.
|
||||
|
||||
Query Parameters:
|
||||
- path (optional): The relative path within the user's data directory
|
||||
to list. Defaults to the root ('').
|
||||
|
||||
Returns:
|
||||
- 400: If the requested path is invalid, outside the user's data directory, or is not a directory.
|
||||
- 404: If the requested path does not exist.
|
||||
- 403: If the user is invalid.
|
||||
- 500: If there is an error reading the directory contents.
|
||||
- 200: JSON response containing a list of file and directory objects.
|
||||
Each object includes:
|
||||
- name: The name of the file or directory.
|
||||
- type: 'file' or 'directory'.
|
||||
- path: The relative path from the user's data root.
|
||||
- size (for files): The size in bytes.
|
||||
- modified (for files): The last modified timestamp (Unix epoch).
|
||||
"""
|
||||
requested_rel_path = request.rel_url.query.get('path', '')
|
||||
|
||||
# URL-decode the path parameter
|
||||
try:
|
||||
requested_rel_path = parse.unquote(requested_rel_path)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to decode path parameter: {requested_rel_path}, Error: {e}")
|
||||
return web.Response(status=400, text="Invalid characters in path parameter")
|
||||
|
||||
|
||||
# Check user validity and get the absolute path for the requested directory
|
||||
try:
|
||||
base_user_path = self.get_request_user_filepath(request, None, create_dir=False)
|
||||
|
||||
if requested_rel_path:
|
||||
target_abs_path = self.get_request_user_filepath(request, requested_rel_path, create_dir=False)
|
||||
else:
|
||||
target_abs_path = base_user_path
|
||||
|
||||
except KeyError as e:
|
||||
# Invalid user detected by get_request_user_id inside get_request_user_filepath
|
||||
logging.warning(f"Access denied for user: {e}")
|
||||
return web.Response(status=403, text="Invalid user specified in request")
|
||||
|
||||
|
||||
if not target_abs_path:
|
||||
# Path traversal or other issue detected by get_request_user_filepath
|
||||
return web.Response(status=400, text="Invalid path requested")
|
||||
|
||||
# Handle cases where the user directory or target path doesn't exist
|
||||
if not os.path.exists(target_abs_path):
|
||||
# Check if it's the base user directory that's missing (new user case)
|
||||
if target_abs_path == base_user_path:
|
||||
# It's okay if the base user directory doesn't exist yet, return empty list
|
||||
return web.json_response([])
|
||||
else:
|
||||
# A specific subdirectory was requested but doesn't exist
|
||||
return web.Response(status=404, text="Requested path not found")
|
||||
|
||||
if not os.path.isdir(target_abs_path):
|
||||
return web.Response(status=400, text="Requested path is not a directory")
|
||||
|
||||
results = []
|
||||
try:
|
||||
for root, dirs, files in os.walk(target_abs_path, topdown=True):
|
||||
# Process directories
|
||||
for dir_name in dirs:
|
||||
dir_path = os.path.join(root, dir_name)
|
||||
rel_path = os.path.relpath(dir_path, base_user_path).replace(os.sep, '/')
|
||||
results.append({
|
||||
"name": dir_name,
|
||||
"path": rel_path,
|
||||
"type": "directory"
|
||||
})
|
||||
|
||||
# Process files
|
||||
for file_name in files:
|
||||
file_path = os.path.join(root, file_name)
|
||||
rel_path = os.path.relpath(file_path, base_user_path).replace(os.sep, '/')
|
||||
entry_info = {
|
||||
"name": file_name,
|
||||
"path": rel_path,
|
||||
"type": "file"
|
||||
}
|
||||
try:
|
||||
stats = os.stat(file_path) # Use os.stat for potentially better performance with os.walk
|
||||
entry_info["size"] = stats.st_size
|
||||
entry_info["modified"] = stats.st_mtime
|
||||
except OSError as stat_error:
|
||||
logging.warning(f"Could not stat file {file_path}: {stat_error}")
|
||||
pass # Include file with available info
|
||||
results.append(entry_info)
|
||||
except OSError as e:
|
||||
logging.error(f"Error listing directory {target_abs_path}: {e}")
|
||||
return web.Response(status=500, text="Error reading directory contents")
|
||||
|
||||
# Sort results alphabetically, directories first then files
|
||||
results.sort(key=lambda x: (x['type'] != 'directory', x['name'].lower()))
|
||||
|
||||
return web.json_response(results)
|
||||
|
||||
def get_user_data_path(request, check_exists = False, param = "file"):
|
||||
file = request.match_info.get(param, None)
|
||||
if not file:
|
||||
|
@ -128,6 +128,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
|
||||
|
||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
|
||||
|
||||
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
|
||||
|
||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||
|
||||
@ -141,12 +142,15 @@ class PerformanceFeature(enum.Enum):
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
|
||||
|
||||
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
||||
|
||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
|
||||
|
||||
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
|
||||
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
|
||||
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
|
||||
@ -191,6 +195,13 @@ parser.add_argument("--user-directory", type=is_valid_directory, default=None, h
|
||||
|
||||
parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
|
||||
|
||||
parser.add_argument(
|
||||
"--comfy-api-base",
|
||||
type=str,
|
||||
default="https://api.comfy.org",
|
||||
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
|
||||
)
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
else:
|
||||
|
@ -18,6 +18,7 @@ class Output:
|
||||
setattr(self, key, item)
|
||||
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Comfy-specific type hinting"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Literal, TypedDict
|
||||
from typing import Literal, TypedDict, Optional
|
||||
from typing_extensions import NotRequired
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
@ -48,6 +48,7 @@ class IO(StrEnum):
|
||||
FACE_ANALYSIS = "FACE_ANALYSIS"
|
||||
BBOX = "BBOX"
|
||||
SEGS = "SEGS"
|
||||
VIDEO = "VIDEO"
|
||||
|
||||
ANY = "*"
|
||||
"""Always matches any type, but at a price.
|
||||
@ -120,6 +121,10 @@ class InputTypeOptions(TypedDict):
|
||||
Available from frontend v1.17.5
|
||||
Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3548
|
||||
"""
|
||||
widgetType: NotRequired[str]
|
||||
"""Specifies a type to be used for widget initialization if different from the input type.
|
||||
Available from frontend v1.18.0
|
||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/3550"""
|
||||
# class InputTypeNumber(InputTypeOptions):
|
||||
# default: float | int
|
||||
min: NotRequired[float]
|
||||
@ -229,6 +234,8 @@ class ComfyNodeABC(ABC):
|
||||
"""Flags a node as experimental, informing users that it may change or not work as expected."""
|
||||
DEPRECATED: bool
|
||||
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
||||
API_NODE: Optional[bool]
|
||||
"""Flags a node as an API node."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
@ -267,7 +274,7 @@ class ComfyNodeABC(ABC):
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
|
||||
"""
|
||||
OUTPUT_IS_LIST: tuple[bool]
|
||||
OUTPUT_IS_LIST: tuple[bool, ...]
|
||||
"""A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
|
||||
|
||||
Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.
|
||||
@ -286,7 +293,7 @@ class ComfyNodeABC(ABC):
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
|
||||
"""
|
||||
|
||||
RETURN_TYPES: tuple[IO]
|
||||
RETURN_TYPES: tuple[IO, ...]
|
||||
"""A tuple representing the outputs of this node.
|
||||
|
||||
Usage::
|
||||
@ -295,12 +302,12 @@ class ComfyNodeABC(ABC):
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-types
|
||||
"""
|
||||
RETURN_NAMES: tuple[str]
|
||||
RETURN_NAMES: tuple[str, ...]
|
||||
"""The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
|
||||
|
||||
Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-names
|
||||
"""
|
||||
OUTPUT_TOOLTIPS: tuple[str]
|
||||
OUTPUT_TOOLTIPS: tuple[str, ...]
|
||||
"""A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
|
||||
FUNCTION: str
|
||||
"""The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`
|
||||
|
@ -116,7 +116,7 @@ class Dino2Embeddings(torch.nn.Module):
|
||||
def forward(self, pixel_values):
|
||||
x = self.patch_embeddings(pixel_values)
|
||||
# TODO: mask_token?
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
|
||||
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
|
||||
return x
|
||||
|
||||
|
@ -1277,6 +1277,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
phi1_fn = lambda t: torch.expm1(t) / t
|
||||
phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
|
||||
|
||||
old_sigma_down = None
|
||||
old_denoised = None
|
||||
uncond_denoised = None
|
||||
def post_cfg_function(args):
|
||||
@ -1304,9 +1305,9 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
x = x + d * dt
|
||||
else:
|
||||
# Second order multistep method in https://arxiv.org/pdf/2308.02157
|
||||
t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigma_down), t_fn(sigmas[i - 1])
|
||||
t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1])
|
||||
h = t_next - t
|
||||
c2 = (t_prev - t) / h
|
||||
c2 = (t_prev - t_old) / h
|
||||
|
||||
phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
|
||||
b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
|
||||
@ -1326,6 +1327,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
old_denoised = uncond_denoised
|
||||
else:
|
||||
old_denoised = denoised
|
||||
old_sigma_down = sigma_down
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
@ -1345,28 +1347,52 @@ def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, cal
|
||||
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
|
||||
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
|
||||
"""Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
old_d = None
|
||||
|
||||
uncond_denoised = None
|
||||
def post_cfg_function(args):
|
||||
nonlocal uncond_denoised
|
||||
uncond_denoised = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
if cfg_pp:
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
if cfg_pp:
|
||||
d = to_d(x, sigmas[i], uncond_denoised)
|
||||
else:
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigmas[i]
|
||||
if i == 0:
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
if cfg_pp:
|
||||
x = denoised + d * sigmas[i + 1]
|
||||
else:
|
||||
x = x + d * dt
|
||||
else:
|
||||
# Gradient estimation
|
||||
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
|
||||
x = x + d_bar * dt
|
||||
if cfg_pp:
|
||||
d_bar = (ge_gamma - 1) * (d - old_d)
|
||||
x = denoised + d * sigmas[i + 1] + d_bar * dt
|
||||
else:
|
||||
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
|
||||
x = x + d_bar * dt
|
||||
old_d = d
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
|
||||
return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
|
||||
"""
|
||||
|
@ -466,3 +466,7 @@ class Hunyuan3Dv2mini(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
scale_factor = 1.0188137142395404
|
||||
|
||||
class ACEAudio(LatentFormat):
|
||||
latent_channels = 8
|
||||
latent_dimensions = 2
|
||||
|
761
comfy/ldm/ace/attention.py
Normal file
761
comfy/ldm/ace/attention.py
Normal file
@ -0,0 +1,761 @@
|
||||
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/attention.py
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Tuple, Union, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
heads: int = 8,
|
||||
kv_heads: Optional[int] = None,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
added_proj_bias: Optional[bool] = True,
|
||||
out_bias: bool = True,
|
||||
scale_qk: bool = True,
|
||||
only_cross_attention: bool = False,
|
||||
eps: float = 1e-5,
|
||||
rescale_output_factor: float = 1.0,
|
||||
residual_connection: bool = False,
|
||||
processor=None,
|
||||
out_dim: int = None,
|
||||
out_context_dim: int = None,
|
||||
context_pre_only=None,
|
||||
pre_only=False,
|
||||
elementwise_affine: bool = True,
|
||||
is_causal: bool = False,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
|
||||
self.query_dim = query_dim
|
||||
self.use_bias = bias
|
||||
self.is_cross_attention = cross_attention_dim is not None
|
||||
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||
self.rescale_output_factor = rescale_output_factor
|
||||
self.residual_connection = residual_connection
|
||||
self.dropout = dropout
|
||||
self.fused_projections = False
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
|
||||
self.context_pre_only = context_pre_only
|
||||
self.pre_only = pre_only
|
||||
self.is_causal = is_causal
|
||||
|
||||
self.scale_qk = scale_qk
|
||||
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
||||
|
||||
self.heads = out_dim // dim_head if out_dim is not None else heads
|
||||
# for slice_size > 0 the attention score computation
|
||||
# is split across the batch axis to save memory
|
||||
# You can set slice_size with `set_attention_slice`
|
||||
self.sliceable_head_dim = heads
|
||||
|
||||
self.added_kv_proj_dim = added_kv_proj_dim
|
||||
self.only_cross_attention = only_cross_attention
|
||||
|
||||
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
||||
raise ValueError(
|
||||
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
||||
)
|
||||
|
||||
self.group_norm = None
|
||||
self.spatial_norm = None
|
||||
|
||||
self.norm_q = None
|
||||
self.norm_k = None
|
||||
|
||||
self.norm_cross = None
|
||||
self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
if not self.only_cross_attention:
|
||||
# only relevant for the `AddedKVProcessor` classes
|
||||
self.to_k = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
|
||||
else:
|
||||
self.to_k = None
|
||||
self.to_v = None
|
||||
|
||||
self.added_proj_bias = added_proj_bias
|
||||
if self.added_kv_proj_dim is not None:
|
||||
self.add_k_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
|
||||
self.add_v_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
|
||||
if self.context_pre_only is not None:
|
||||
self.add_q_proj = operations.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias, dtype=dtype, device=device)
|
||||
else:
|
||||
self.add_q_proj = None
|
||||
self.add_k_proj = None
|
||||
self.add_v_proj = None
|
||||
|
||||
if not self.pre_only:
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
else:
|
||||
self.to_out = None
|
||||
|
||||
if self.context_pre_only is not None and not self.context_pre_only:
|
||||
self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
|
||||
else:
|
||||
self.to_add_out = None
|
||||
|
||||
self.norm_added_q = None
|
||||
self.norm_added_k = None
|
||||
self.processor = processor
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**cross_attention_kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.processor(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
|
||||
class CustomLiteLAProcessor2_0:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections. add rms norm for query and key and apply RoPE"""
|
||||
|
||||
def __init__(self):
|
||||
self.kernel_func = nn.ReLU(inplace=False)
|
||||
self.eps = 1e-15
|
||||
self.pad_val = 1.0
|
||||
|
||||
def apply_rotary_emb(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
||||
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
||||
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
||||
tensors contain rotary embeddings and are returned as real tensors.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`):
|
||||
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
|
||||
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
||||
"""
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
cos = cos[None, None]
|
||||
sin = sin[None, None]
|
||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
|
||||
return out
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
hidden_states_len = hidden_states.shape[1]
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
if encoder_hidden_states is not None:
|
||||
context_input_ndim = encoder_hidden_states.ndim
|
||||
if context_input_ndim == 4:
|
||||
batch_size, channel, height, width = encoder_hidden_states.shape
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
dtype = hidden_states.dtype
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
# `context` projections.
|
||||
has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
|
||||
if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
|
||||
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
# attention
|
||||
if not attn.is_cross_attention:
|
||||
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
|
||||
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
|
||||
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
|
||||
else:
|
||||
query = hidden_states
|
||||
key = encoder_hidden_states
|
||||
value = encoder_hidden_states
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
|
||||
key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
|
||||
value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
|
||||
|
||||
# RoPE需要 [B, H, S, D] 输入
|
||||
# 此时 query是 [B, H, D, S], 需要转成 [B, H, S, D] 才能应用RoPE
|
||||
query = query.permute(0, 1, 3, 2) # [B, H, S, D] (从 [B, H, D, S])
|
||||
|
||||
# Apply query and key normalization if needed
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if rotary_freqs_cis is not None:
|
||||
query = self.apply_rotary_emb(query, rotary_freqs_cis)
|
||||
if not attn.is_cross_attention:
|
||||
key = self.apply_rotary_emb(key, rotary_freqs_cis)
|
||||
elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
|
||||
key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
|
||||
|
||||
# 此时 query是 [B, H, S, D],需要还原成 [B, H, D, S]
|
||||
query = query.permute(0, 1, 3, 2) # [B, H, D, S]
|
||||
|
||||
if attention_mask is not None:
|
||||
# attention_mask: [B, S] -> [B, 1, S, 1]
|
||||
attention_mask = attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S, 1]
|
||||
query = query * attention_mask.permute(0, 1, 3, 2) # [B, H, S, D] * [B, 1, S, 1]
|
||||
if not attn.is_cross_attention:
|
||||
key = key * attention_mask # key: [B, h, S, D] 与 mask [B, 1, S, 1] 相乘
|
||||
value = value * attention_mask.permute(0, 1, 3, 2) # 如果 value 是 [B, h, D, S],那么需调整mask以匹配S维度
|
||||
|
||||
if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
|
||||
encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S_enc, 1]
|
||||
# 此时 key: [B, h, S_enc, D], value: [B, h, D, S_enc]
|
||||
key = key * encoder_attention_mask # [B, h, S_enc, D] * [B, 1, S_enc, 1]
|
||||
value = value * encoder_attention_mask.permute(0, 1, 3, 2) # [B, h, D, S_enc] * [B, 1, 1, S_enc]
|
||||
|
||||
query = self.kernel_func(query)
|
||||
key = self.kernel_func(key)
|
||||
|
||||
query, key, value = query.float(), key.float(), value.float()
|
||||
|
||||
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
|
||||
|
||||
vk = torch.matmul(value, key)
|
||||
|
||||
hidden_states = torch.matmul(vk, query)
|
||||
|
||||
if hidden_states.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.float()
|
||||
|
||||
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
|
||||
|
||||
hidden_states = hidden_states.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)
|
||||
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states = encoder_hidden_states.to(dtype)
|
||||
|
||||
# Split the attention outputs.
|
||||
if encoder_hidden_states is not None and not attn.is_cross_attention and has_encoder_hidden_state_proj:
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, : hidden_states_len],
|
||||
hidden_states[:, hidden_states_len:],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
if encoder_hidden_states is not None and not attn.context_pre_only and not attn.is_cross_attention and hasattr(attn, "to_add_out"):
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
if encoder_hidden_states is not None and context_input_ndim == 4:
|
||||
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if torch.get_autocast_gpu_dtype() == torch.float16:
|
||||
hidden_states = hidden_states.clip(-65504, 65504)
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class CustomerAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
"""
|
||||
|
||||
def apply_rotary_emb(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
||||
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
||||
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
||||
tensors contain rotary embeddings and are returned as real tensors.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`):
|
||||
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
|
||||
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
||||
"""
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
cos = cos[None, None]
|
||||
sin = sin[None, None]
|
||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
|
||||
return out
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
residual = hidden_states
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if rotary_freqs_cis is not None:
|
||||
query = self.apply_rotary_emb(query, rotary_freqs_cis)
|
||||
if not attn.is_cross_attention:
|
||||
key = self.apply_rotary_emb(key, rotary_freqs_cis)
|
||||
elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
|
||||
key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
|
||||
|
||||
if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
|
||||
# attention_mask: N x S1
|
||||
# encoder_attention_mask: N x S2
|
||||
# cross attention 整合attention_mask和encoder_attention_mask
|
||||
combined_mask = attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
|
||||
attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
|
||||
attention_mask = attention_mask[:, None, :, :].expand(-1, attn.heads, -1, -1).to(query.dtype)
|
||||
|
||||
elif not attn.is_cross_attention and attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
hidden_states = optimized_attention(
|
||||
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
|
||||
).to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
def val2list(x: list or tuple or any, repeat_time=1) -> list: # type: ignore
|
||||
"""Repeat `val` for `repeat_time` times and return the list or val if list/tuple."""
|
||||
if isinstance(x, (list, tuple)):
|
||||
return list(x)
|
||||
return [x for _ in range(repeat_time)]
|
||||
|
||||
|
||||
def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple: # type: ignore
|
||||
"""Return tuple with min_len by repeating element at idx_repeat."""
|
||||
# convert to list first
|
||||
x = val2list(x)
|
||||
|
||||
# repeat elements if necessary
|
||||
if len(x) > 0:
|
||||
x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
|
||||
|
||||
return tuple(x)
|
||||
|
||||
|
||||
def t2i_modulate(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
|
||||
def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
|
||||
if isinstance(kernel_size, tuple):
|
||||
return tuple([get_same_padding(ks) for ks in kernel_size])
|
||||
else:
|
||||
assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
|
||||
return kernel_size // 2
|
||||
|
||||
class ConvLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_dim: int,
|
||||
out_dim: int,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
padding: Union[int, None] = None,
|
||||
use_bias=False,
|
||||
norm=None,
|
||||
act=None,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
if padding is None:
|
||||
padding = get_same_padding(kernel_size)
|
||||
padding *= dilation
|
||||
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.dilation = dilation
|
||||
self.groups = groups
|
||||
self.padding = padding
|
||||
self.use_bias = use_bias
|
||||
|
||||
self.conv = operations.Conv1d(
|
||||
in_dim,
|
||||
out_dim,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=use_bias,
|
||||
device=device,
|
||||
dtype=dtype
|
||||
)
|
||||
if norm is not None:
|
||||
self.norm = operations.RMSNorm(out_dim, elementwise_affine=False, dtype=dtype, device=device)
|
||||
else:
|
||||
self.norm = None
|
||||
if act is not None:
|
||||
self.act = nn.SiLU(inplace=True)
|
||||
else:
|
||||
self.act = None
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.conv(x)
|
||||
if self.norm:
|
||||
x = self.norm(x)
|
||||
if self.act:
|
||||
x = self.act(x)
|
||||
return x
|
||||
|
||||
|
||||
class GLUMBConv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: int,
|
||||
out_feature=None,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding: Union[int, None] = None,
|
||||
use_bias=False,
|
||||
norm=(None, None, None),
|
||||
act=("silu", "silu", None),
|
||||
dilation=1,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
out_feature = out_feature or in_features
|
||||
super().__init__()
|
||||
use_bias = val2tuple(use_bias, 3)
|
||||
norm = val2tuple(norm, 3)
|
||||
act = val2tuple(act, 3)
|
||||
|
||||
self.glu_act = nn.SiLU(inplace=False)
|
||||
self.inverted_conv = ConvLayer(
|
||||
in_features,
|
||||
hidden_features * 2,
|
||||
1,
|
||||
use_bias=use_bias[0],
|
||||
norm=norm[0],
|
||||
act=act[0],
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.depth_conv = ConvLayer(
|
||||
hidden_features * 2,
|
||||
hidden_features * 2,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
groups=hidden_features * 2,
|
||||
padding=padding,
|
||||
use_bias=use_bias[1],
|
||||
norm=norm[1],
|
||||
act=None,
|
||||
dilation=dilation,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.point_conv = ConvLayer(
|
||||
hidden_features,
|
||||
out_feature,
|
||||
1,
|
||||
use_bias=use_bias[2],
|
||||
norm=norm[2],
|
||||
act=act[2],
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x.transpose(1, 2)
|
||||
x = self.inverted_conv(x)
|
||||
x = self.depth_conv(x)
|
||||
|
||||
x, gate = torch.chunk(x, 2, dim=1)
|
||||
gate = self.glu_act(gate)
|
||||
x = x * gate
|
||||
|
||||
x = self.point_conv(x)
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LinearTransformerBlock(nn.Module):
|
||||
"""
|
||||
A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
use_adaln_single=True,
|
||||
cross_attention_dim=None,
|
||||
added_kv_proj_dim=None,
|
||||
context_pre_only=False,
|
||||
mlp_ratio=4.0,
|
||||
add_cross_attention=False,
|
||||
add_cross_attention_dim=None,
|
||||
qk_norm=None,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = operations.RMSNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
added_kv_proj_dim=added_kv_proj_dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
bias=True,
|
||||
qk_norm=qk_norm,
|
||||
processor=CustomLiteLAProcessor2_0(),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.add_cross_attention = add_cross_attention
|
||||
self.context_pre_only = context_pre_only
|
||||
|
||||
if add_cross_attention and add_cross_attention_dim is not None:
|
||||
self.cross_attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=add_cross_attention_dim,
|
||||
added_kv_proj_dim=add_cross_attention_dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
context_pre_only=context_pre_only,
|
||||
bias=True,
|
||||
qk_norm=qk_norm,
|
||||
processor=CustomerAttnProcessor2_0(),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.norm2 = operations.RMSNorm(dim, 1e-06, elementwise_affine=False)
|
||||
|
||||
self.ff = GLUMBConv(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
use_bias=(True, True, False),
|
||||
norm=(None, None, None),
|
||||
act=("silu", "silu", None),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.use_adaln_single = use_adaln_single
|
||||
if use_adaln_single:
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: torch.FloatTensor = None,
|
||||
encoder_attention_mask: torch.FloatTensor = None,
|
||||
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
||||
temb: torch.FloatTensor = None,
|
||||
):
|
||||
|
||||
N = hidden_states.shape[0]
|
||||
|
||||
# step 1: AdaLN single
|
||||
if self.use_adaln_single:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
comfy.model_management.cast_to(self.scale_shift_table[None], dtype=temb.dtype, device=temb.device) + temb.reshape(N, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
if self.use_adaln_single:
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
|
||||
# step 2: attention
|
||||
if not self.add_cross_attention:
|
||||
attn_output, encoder_hidden_states = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
|
||||
)
|
||||
else:
|
||||
attn_output, _ = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=None,
|
||||
)
|
||||
|
||||
if self.use_adaln_single:
|
||||
attn_output = gate_msa * attn_output
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
if self.add_cross_attention:
|
||||
attn_output = self.cross_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# step 3: add norm
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
if self.use_adaln_single:
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
# step 4: feed forward
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
if self.use_adaln_single:
|
||||
ff_output = gate_mlp * ff_output
|
||||
|
||||
hidden_states = hidden_states + ff_output
|
||||
|
||||
return hidden_states
|
1067
comfy/ldm/ace/lyric_encoder.py
Normal file
1067
comfy/ldm/ace/lyric_encoder.py
Normal file
File diff suppressed because it is too large
Load Diff
385
comfy/ldm/ace/model.py
Normal file
385
comfy/ldm/ace/model.py
Normal file
@ -0,0 +1,385 @@
|
||||
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/ace_step_transformer.py
|
||||
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, List, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import comfy.model_management
|
||||
|
||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
from .attention import LinearTransformerBlock, t2i_modulate
|
||||
from .lyric_encoder import ConformerEncoder as LyricEncoder
|
||||
|
||||
|
||||
def cross_norm(hidden_states, controlnet_input):
|
||||
# input N x T x c
|
||||
mean_hidden_states, std_hidden_states = hidden_states.mean(dim=(1,2), keepdim=True), hidden_states.std(dim=(1,2), keepdim=True)
|
||||
mean_controlnet_input, std_controlnet_input = controlnet_input.mean(dim=(1,2), keepdim=True), controlnet_input.std(dim=(1,2), keepdim=True)
|
||||
controlnet_input = (controlnet_input - mean_controlnet_input) * (std_hidden_states / (std_controlnet_input + 1e-12)) + mean_hidden_states
|
||||
return controlnet_input
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
|
||||
class Qwen2RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, dtype=None, device=None):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=device).float() / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
|
||||
)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||
|
||||
return (
|
||||
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
||||
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
||||
)
|
||||
|
||||
|
||||
class T2IFinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of Sana.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(2, hidden_size, dtype=dtype, device=device))
|
||||
self.out_channels = out_channels
|
||||
self.patch_size = patch_size
|
||||
|
||||
def unpatchfy(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
width: int,
|
||||
):
|
||||
# 4 unpatchify
|
||||
new_height, new_width = 1, hidden_states.size(1)
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(hidden_states.shape[0], new_height, new_width, self.patch_size[0], self.patch_size[1], self.out_channels)
|
||||
).contiguous()
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(hidden_states.shape[0], self.out_channels, new_height * self.patch_size[0], new_width * self.patch_size[1])
|
||||
).contiguous()
|
||||
if width > new_width:
|
||||
output = torch.nn.functional.pad(output, (0, width - new_width, 0, 0), 'constant', 0)
|
||||
elif width < new_width:
|
||||
output = output[:, :, :, :width]
|
||||
return output
|
||||
|
||||
def forward(self, x, t, output_length):
|
||||
shift, scale = (comfy.model_management.cast_to(self.scale_shift_table[None], device=t.device, dtype=t.dtype) + t[:, None]).chunk(2, dim=1)
|
||||
x = t2i_modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
# unpatchify
|
||||
output = self.unpatchfy(x, output_length)
|
||||
return output
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""2D Image to Patch Embedding"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
height=16,
|
||||
width=4096,
|
||||
patch_size=(16, 1),
|
||||
in_channels=8,
|
||||
embed_dim=1152,
|
||||
bias=True,
|
||||
dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
patch_size_h, patch_size_w = patch_size
|
||||
self.early_conv_layers = nn.Sequential(
|
||||
operations.Conv2d(in_channels, in_channels*256, kernel_size=patch_size, stride=patch_size, padding=0, bias=bias, dtype=dtype, device=device),
|
||||
operations.GroupNorm(num_groups=32, num_channels=in_channels*256, eps=1e-6, affine=True, dtype=dtype, device=device),
|
||||
operations.Conv2d(in_channels*256, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias, dtype=dtype, device=device)
|
||||
)
|
||||
self.patch_size = patch_size
|
||||
self.height, self.width = height // patch_size_h, width // patch_size_w
|
||||
self.base_size = self.width
|
||||
|
||||
def forward(self, latent):
|
||||
# early convolutions, N x C x H x W -> N x 256 * sqrt(patch_size) x H/patch_size x W/patch_size
|
||||
latent = self.early_conv_layers(latent)
|
||||
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
return latent
|
||||
|
||||
|
||||
class ACEStepTransformer2DModel(nn.Module):
|
||||
# _supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: Optional[int] = 8,
|
||||
num_layers: int = 28,
|
||||
inner_dim: int = 1536,
|
||||
attention_head_dim: int = 64,
|
||||
num_attention_heads: int = 24,
|
||||
mlp_ratio: float = 4.0,
|
||||
out_channels: int = 8,
|
||||
max_position: int = 32768,
|
||||
rope_theta: float = 1000000.0,
|
||||
speaker_embedding_dim: int = 512,
|
||||
text_embedding_dim: int = 768,
|
||||
ssl_encoder_depths: List[int] = [9, 9],
|
||||
ssl_names: List[str] = ["mert", "m-hubert"],
|
||||
ssl_latent_dims: List[int] = [1024, 768],
|
||||
lyric_encoder_vocab_size: int = 6681,
|
||||
lyric_hidden_size: int = 1024,
|
||||
patch_size: List[int] = [16, 1],
|
||||
max_height: int = 16,
|
||||
max_width: int = 4096,
|
||||
audio_model=None,
|
||||
dtype=None, device=None, operations=None
|
||||
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dtype = dtype
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.inner_dim = inner_dim
|
||||
self.out_channels = out_channels
|
||||
self.max_position = max_position
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.rope_theta = rope_theta
|
||||
|
||||
self.rotary_emb = Qwen2RotaryEmbedding(
|
||||
dim=self.attention_head_dim,
|
||||
max_position_embeddings=self.max_position,
|
||||
base=self.rope_theta,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# 2. Define input layers
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.num_layers = num_layers
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
LinearTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
mlp_ratio=mlp_ratio,
|
||||
add_cross_attention=True,
|
||||
add_cross_attention_dim=self.inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for i in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.t_block = nn.Sequential(nn.SiLU(), operations.Linear(self.inner_dim, 6 * self.inner_dim, bias=True, dtype=dtype, device=device))
|
||||
|
||||
# speaker
|
||||
self.speaker_embedder = operations.Linear(speaker_embedding_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
|
||||
# genre
|
||||
self.genre_embedder = operations.Linear(text_embedding_dim, self.inner_dim, dtype=dtype, device=device)
|
||||
|
||||
# lyric
|
||||
self.lyric_embs = operations.Embedding(lyric_encoder_vocab_size, lyric_hidden_size, dtype=dtype, device=device)
|
||||
self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0, dtype=dtype, device=device, operations=operations)
|
||||
self.lyric_proj = operations.Linear(lyric_hidden_size, self.inner_dim, dtype=dtype, device=device)
|
||||
|
||||
projector_dim = 2 * self.inner_dim
|
||||
|
||||
self.projectors = nn.ModuleList([
|
||||
nn.Sequential(
|
||||
operations.Linear(self.inner_dim, projector_dim, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(projector_dim, projector_dim, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(projector_dim, ssl_dim, dtype=dtype, device=device),
|
||||
) for ssl_dim in ssl_latent_dims
|
||||
])
|
||||
|
||||
self.proj_in = PatchEmbed(
|
||||
height=max_height,
|
||||
width=max_width,
|
||||
patch_size=patch_size,
|
||||
embed_dim=self.inner_dim,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward_lyric_encoder(
|
||||
self,
|
||||
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||
lyric_mask: Optional[torch.LongTensor] = None,
|
||||
out_dtype=None,
|
||||
):
|
||||
# N x T x D
|
||||
lyric_embs = self.lyric_embs(lyric_token_idx, out_dtype=out_dtype)
|
||||
prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
|
||||
prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
|
||||
return prompt_prenet_out
|
||||
|
||||
def encode(
|
||||
self,
|
||||
encoder_text_hidden_states: Optional[torch.Tensor] = None,
|
||||
text_attention_mask: Optional[torch.LongTensor] = None,
|
||||
speaker_embeds: Optional[torch.FloatTensor] = None,
|
||||
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||
lyric_mask: Optional[torch.LongTensor] = None,
|
||||
lyrics_strength=1.0,
|
||||
):
|
||||
|
||||
bs = encoder_text_hidden_states.shape[0]
|
||||
device = encoder_text_hidden_states.device
|
||||
|
||||
# speaker embedding
|
||||
encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)
|
||||
|
||||
# genre embedding
|
||||
encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)
|
||||
|
||||
# lyric
|
||||
encoder_lyric_hidden_states = self.forward_lyric_encoder(
|
||||
lyric_token_idx=lyric_token_idx,
|
||||
lyric_mask=lyric_mask,
|
||||
out_dtype=encoder_text_hidden_states.dtype,
|
||||
)
|
||||
|
||||
encoder_lyric_hidden_states *= lyrics_strength
|
||||
|
||||
encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1)
|
||||
|
||||
encoder_hidden_mask = None
|
||||
if text_attention_mask is not None:
|
||||
speaker_mask = torch.ones(bs, 1, device=device)
|
||||
encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)
|
||||
|
||||
return encoder_hidden_states, encoder_hidden_mask
|
||||
|
||||
def decode(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_hidden_mask: torch.Tensor,
|
||||
timestep: Optional[torch.Tensor],
|
||||
output_length: int = 0,
|
||||
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||
):
|
||||
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
|
||||
temb = self.t_block(embedded_timestep)
|
||||
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
# controlnet logic
|
||||
if block_controlnet_hidden_states is not None:
|
||||
control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
|
||||
hidden_states = hidden_states + control_condi * controlnet_scale
|
||||
|
||||
# inner_hidden_states = []
|
||||
|
||||
rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
|
||||
encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_hidden_mask,
|
||||
rotary_freqs_cis=rotary_freqs_cis,
|
||||
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
|
||||
temb=temb,
|
||||
)
|
||||
|
||||
output = self.final_layer(hidden_states, embedded_timestep, output_length)
|
||||
return output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timestep,
|
||||
attention_mask=None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
text_attention_mask: Optional[torch.LongTensor] = None,
|
||||
speaker_embeds: Optional[torch.FloatTensor] = None,
|
||||
lyric_token_idx: Optional[torch.LongTensor] = None,
|
||||
lyric_mask: Optional[torch.LongTensor] = None,
|
||||
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
controlnet_scale: Union[float, torch.Tensor] = 1.0,
|
||||
lyrics_strength=1.0,
|
||||
**kwargs
|
||||
):
|
||||
hidden_states = x
|
||||
encoder_text_hidden_states = context
|
||||
encoder_hidden_states, encoder_hidden_mask = self.encode(
|
||||
encoder_text_hidden_states=encoder_text_hidden_states,
|
||||
text_attention_mask=text_attention_mask,
|
||||
speaker_embeds=speaker_embeds,
|
||||
lyric_token_idx=lyric_token_idx,
|
||||
lyric_mask=lyric_mask,
|
||||
lyrics_strength=lyrics_strength,
|
||||
)
|
||||
|
||||
output_length = hidden_states.shape[-1]
|
||||
|
||||
output = self.decode(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_mask=encoder_hidden_mask,
|
||||
timestep=timestep,
|
||||
output_length=output_length,
|
||||
block_controlnet_hidden_states=block_controlnet_hidden_states,
|
||||
controlnet_scale=controlnet_scale,
|
||||
)
|
||||
|
||||
return output
|
644
comfy/ldm/ace/vae/autoencoder_dc.py
Normal file
644
comfy/ldm/ace/vae/autoencoder_dc.py
Normal file
@ -0,0 +1,644 @@
|
||||
# Rewritten from diffusers
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple, Union
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class RMSNorm(ops.RMSNorm):
|
||||
def __init__(self, dim, eps=1e-5, elementwise_affine=True, bias=False):
|
||||
super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
if elementwise_affine:
|
||||
self.bias = nn.Parameter(torch.empty(dim)) if bias else None
|
||||
|
||||
def forward(self, x):
|
||||
x = super().forward(x)
|
||||
if self.elementwise_affine:
|
||||
if self.bias is not None:
|
||||
x = x + comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device)
|
||||
return x
|
||||
|
||||
|
||||
def get_normalization(norm_type, num_features, num_groups=32, eps=1e-5):
|
||||
if norm_type == "batch_norm":
|
||||
return nn.BatchNorm2d(num_features)
|
||||
elif norm_type == "group_norm":
|
||||
return ops.GroupNorm(num_groups, num_features)
|
||||
elif norm_type == "layer_norm":
|
||||
return ops.LayerNorm(num_features)
|
||||
elif norm_type == "rms_norm":
|
||||
return RMSNorm(num_features, eps=eps, elementwise_affine=True, bias=True)
|
||||
else:
|
||||
raise ValueError(f"Unknown normalization type: {norm_type}")
|
||||
|
||||
|
||||
def get_activation(activation_type):
|
||||
if activation_type == "relu":
|
||||
return nn.ReLU()
|
||||
elif activation_type == "relu6":
|
||||
return nn.ReLU6()
|
||||
elif activation_type == "silu":
|
||||
return nn.SiLU()
|
||||
elif activation_type == "leaky_relu":
|
||||
return nn.LeakyReLU(0.2)
|
||||
else:
|
||||
raise ValueError(f"Unknown activation type: {activation_type}")
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
norm_type: str = "batch_norm",
|
||||
act_fn: str = "relu6",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.norm_type = norm_type
|
||||
self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity()
|
||||
self.conv1 = ops.Conv2d(in_channels, in_channels, 3, 1, 1)
|
||||
self.conv2 = ops.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
|
||||
self.norm = get_normalization(norm_type, out_channels)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.norm_type == "rms_norm":
|
||||
# move channel to the last dimension so we apply RMSnorm across channel dimension
|
||||
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
else:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
return hidden_states + residual
|
||||
|
||||
class SanaMultiscaleAttentionProjection(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
num_attention_heads: int,
|
||||
kernel_size: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
channels = 3 * in_channels
|
||||
self.proj_in = ops.Conv2d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
groups=channels,
|
||||
bias=False,
|
||||
)
|
||||
self.proj_out = ops.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
class SanaMultiscaleLinearAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_attention_heads: int = None,
|
||||
attention_head_dim: int = 8,
|
||||
mult: float = 1.0,
|
||||
norm_type: str = "batch_norm",
|
||||
kernel_sizes: tuple = (5,),
|
||||
eps: float = 1e-15,
|
||||
residual_connection: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.eps = eps
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.norm_type = norm_type
|
||||
self.residual_connection = residual_connection
|
||||
|
||||
num_attention_heads = (
|
||||
int(in_channels // attention_head_dim * mult)
|
||||
if num_attention_heads is None
|
||||
else num_attention_heads
|
||||
)
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.to_q = ops.Linear(in_channels, inner_dim, bias=False)
|
||||
self.to_k = ops.Linear(in_channels, inner_dim, bias=False)
|
||||
self.to_v = ops.Linear(in_channels, inner_dim, bias=False)
|
||||
|
||||
self.to_qkv_multiscale = nn.ModuleList()
|
||||
for kernel_size in kernel_sizes:
|
||||
self.to_qkv_multiscale.append(
|
||||
SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
|
||||
)
|
||||
|
||||
self.nonlinearity = nn.ReLU()
|
||||
self.to_out = ops.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
|
||||
self.norm_out = get_normalization(norm_type, out_channels)
|
||||
|
||||
def apply_linear_attention(self, query, key, value):
|
||||
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)
|
||||
scores = torch.matmul(value, key.transpose(-1, -2))
|
||||
hidden_states = torch.matmul(scores, query)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=torch.float32)
|
||||
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
|
||||
return hidden_states
|
||||
|
||||
def apply_quadratic_attention(self, query, key, value):
|
||||
scores = torch.matmul(key.transpose(-1, -2), query)
|
||||
scores = scores.to(dtype=torch.float32)
|
||||
scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
|
||||
hidden_states = torch.matmul(value, scores.to(value.dtype))
|
||||
return hidden_states
|
||||
|
||||
def forward(self, hidden_states):
|
||||
height, width = hidden_states.shape[-2:]
|
||||
if height * width > self.attention_head_dim:
|
||||
use_linear_attention = True
|
||||
else:
|
||||
use_linear_attention = False
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
batch_size, _, height, width = list(hidden_states.size())
|
||||
original_dtype = hidden_states.dtype
|
||||
|
||||
hidden_states = hidden_states.movedim(1, -1)
|
||||
query = self.to_q(hidden_states)
|
||||
key = self.to_k(hidden_states)
|
||||
value = self.to_v(hidden_states)
|
||||
hidden_states = torch.cat([query, key, value], dim=3)
|
||||
hidden_states = hidden_states.movedim(-1, 1)
|
||||
|
||||
multi_scale_qkv = [hidden_states]
|
||||
for block in self.to_qkv_multiscale:
|
||||
multi_scale_qkv.append(block(hidden_states))
|
||||
|
||||
hidden_states = torch.cat(multi_scale_qkv, dim=1)
|
||||
|
||||
if use_linear_attention:
|
||||
# for linear attention upcast hidden_states to float32
|
||||
hidden_states = hidden_states.to(dtype=torch.float32)
|
||||
|
||||
hidden_states = hidden_states.reshape(batch_size, -1, 3 * self.attention_head_dim, height * width)
|
||||
|
||||
query, key, value = hidden_states.chunk(3, dim=2)
|
||||
query = self.nonlinearity(query)
|
||||
key = self.nonlinearity(key)
|
||||
|
||||
if use_linear_attention:
|
||||
hidden_states = self.apply_linear_attention(query, key, value)
|
||||
hidden_states = hidden_states.to(dtype=original_dtype)
|
||||
else:
|
||||
hidden_states = self.apply_quadratic_attention(query, key, value)
|
||||
|
||||
hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
|
||||
hidden_states = self.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.norm_type == "rms_norm":
|
||||
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
else:
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
|
||||
if self.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class EfficientViTBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
mult: float = 1.0,
|
||||
attention_head_dim: int = 32,
|
||||
qkv_multiscales: tuple = (5,),
|
||||
norm_type: str = "batch_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.attn = SanaMultiscaleLinearAttention(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
mult=mult,
|
||||
attention_head_dim=attention_head_dim,
|
||||
norm_type=norm_type,
|
||||
kernel_sizes=qkv_multiscales,
|
||||
residual_connection=True,
|
||||
)
|
||||
|
||||
self.conv_out = GLUMBConv(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
norm_type="rms_norm",
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.attn(x)
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class GLUMBConv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
expand_ratio: float = 4,
|
||||
norm_type: str = None,
|
||||
residual_connection: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_channels = int(expand_ratio * in_channels)
|
||||
self.norm_type = norm_type
|
||||
self.residual_connection = residual_connection
|
||||
|
||||
self.nonlinearity = nn.SiLU()
|
||||
self.conv_inverted = ops.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
|
||||
self.conv_depth = ops.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
|
||||
self.conv_point = ops.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
|
||||
|
||||
self.norm = None
|
||||
if norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.residual_connection:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.conv_inverted(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.conv_depth(hidden_states)
|
||||
hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
|
||||
hidden_states = hidden_states * self.nonlinearity(gate)
|
||||
|
||||
hidden_states = self.conv_point(hidden_states)
|
||||
|
||||
if self.norm_type == "rms_norm":
|
||||
# move channel to the last dimension so we apply RMSnorm across channel dimension
|
||||
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def get_block(
|
||||
block_type: str,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
attention_head_dim: int,
|
||||
norm_type: str,
|
||||
act_fn: str,
|
||||
qkv_mutliscales: tuple = (),
|
||||
):
|
||||
if block_type == "ResBlock":
|
||||
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
|
||||
elif block_type == "EfficientViTBlock":
|
||||
block = EfficientViTBlock(
|
||||
in_channels,
|
||||
attention_head_dim=attention_head_dim,
|
||||
norm_type=norm_type,
|
||||
qkv_multiscales=qkv_mutliscales
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Block with {block_type=} is not supported.")
|
||||
|
||||
return block
|
||||
|
||||
|
||||
class DCDownBlock2d(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, downsample: bool = False, shortcut: bool = True) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.downsample = downsample
|
||||
self.factor = 2
|
||||
self.stride = 1 if downsample else 2
|
||||
self.group_size = in_channels * self.factor**2 // out_channels
|
||||
self.shortcut = shortcut
|
||||
|
||||
out_ratio = self.factor**2
|
||||
if downsample:
|
||||
assert out_channels % out_ratio == 0
|
||||
out_channels = out_channels // out_ratio
|
||||
|
||||
self.conv = ops.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=self.stride,
|
||||
padding=1,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
x = self.conv(hidden_states)
|
||||
if self.downsample:
|
||||
x = F.pixel_unshuffle(x, self.factor)
|
||||
|
||||
if self.shortcut:
|
||||
y = F.pixel_unshuffle(hidden_states, self.factor)
|
||||
y = y.unflatten(1, (-1, self.group_size))
|
||||
y = y.mean(dim=2)
|
||||
hidden_states = x + y
|
||||
else:
|
||||
hidden_states = x
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DCUpBlock2d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
interpolate: bool = False,
|
||||
shortcut: bool = True,
|
||||
interpolation_mode: str = "nearest",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.interpolate = interpolate
|
||||
self.interpolation_mode = interpolation_mode
|
||||
self.shortcut = shortcut
|
||||
self.factor = 2
|
||||
self.repeats = out_channels * self.factor**2 // in_channels
|
||||
|
||||
out_ratio = self.factor**2
|
||||
if not interpolate:
|
||||
out_channels = out_channels * out_ratio
|
||||
|
||||
self.conv = ops.Conv2d(in_channels, out_channels, 3, 1, 1)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.interpolate:
|
||||
x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = self.conv(hidden_states)
|
||||
x = F.pixel_shuffle(x, self.factor)
|
||||
|
||||
if self.shortcut:
|
||||
y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
|
||||
y = F.pixel_shuffle(y, self.factor)
|
||||
hidden_states = x + y
|
||||
else:
|
||||
hidden_states = x
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
latent_channels: int,
|
||||
attention_head_dim: int = 32,
|
||||
block_type: str or tuple = "ResBlock",
|
||||
block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: tuple = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)),
|
||||
downsample_block_type: str = "pixel_unshuffle",
|
||||
out_shortcut: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
num_blocks = len(block_out_channels)
|
||||
|
||||
if isinstance(block_type, str):
|
||||
block_type = (block_type,) * num_blocks
|
||||
|
||||
if layers_per_block[0] > 0:
|
||||
self.conv_in = ops.Conv2d(
|
||||
in_channels,
|
||||
block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
)
|
||||
else:
|
||||
self.conv_in = DCDownBlock2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
|
||||
downsample=downsample_block_type == "pixel_unshuffle",
|
||||
shortcut=False,
|
||||
)
|
||||
|
||||
down_blocks = []
|
||||
for i, (out_channel, num_layers) in enumerate(zip(block_out_channels, layers_per_block)):
|
||||
down_block_list = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
block = get_block(
|
||||
block_type[i],
|
||||
out_channel,
|
||||
out_channel,
|
||||
attention_head_dim=attention_head_dim,
|
||||
norm_type="rms_norm",
|
||||
act_fn="silu",
|
||||
qkv_mutliscales=qkv_multiscales[i],
|
||||
)
|
||||
down_block_list.append(block)
|
||||
|
||||
if i < num_blocks - 1 and num_layers > 0:
|
||||
downsample_block = DCDownBlock2d(
|
||||
in_channels=out_channel,
|
||||
out_channels=block_out_channels[i + 1],
|
||||
downsample=downsample_block_type == "pixel_unshuffle",
|
||||
shortcut=True,
|
||||
)
|
||||
down_block_list.append(downsample_block)
|
||||
|
||||
down_blocks.append(nn.Sequential(*down_block_list))
|
||||
|
||||
self.down_blocks = nn.ModuleList(down_blocks)
|
||||
|
||||
self.conv_out = ops.Conv2d(block_out_channels[-1], latent_channels, 3, 1, 1)
|
||||
|
||||
self.out_shortcut = out_shortcut
|
||||
if out_shortcut:
|
||||
self.out_shortcut_average_group_size = block_out_channels[-1] // latent_channels
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states)
|
||||
|
||||
if self.out_shortcut:
|
||||
x = hidden_states.unflatten(1, (-1, self.out_shortcut_average_group_size))
|
||||
x = x.mean(dim=2)
|
||||
hidden_states = self.conv_out(hidden_states) + x
|
||||
else:
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
latent_channels: int,
|
||||
attention_head_dim: int = 32,
|
||||
block_type: str or tuple = "ResBlock",
|
||||
block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: tuple = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)),
|
||||
norm_type: str or tuple = "rms_norm",
|
||||
act_fn: str or tuple = "silu",
|
||||
upsample_block_type: str = "pixel_shuffle",
|
||||
in_shortcut: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
num_blocks = len(block_out_channels)
|
||||
|
||||
if isinstance(block_type, str):
|
||||
block_type = (block_type,) * num_blocks
|
||||
if isinstance(norm_type, str):
|
||||
norm_type = (norm_type,) * num_blocks
|
||||
if isinstance(act_fn, str):
|
||||
act_fn = (act_fn,) * num_blocks
|
||||
|
||||
self.conv_in = ops.Conv2d(latent_channels, block_out_channels[-1], 3, 1, 1)
|
||||
|
||||
self.in_shortcut = in_shortcut
|
||||
if in_shortcut:
|
||||
self.in_shortcut_repeats = block_out_channels[-1] // latent_channels
|
||||
|
||||
up_blocks = []
|
||||
for i, (out_channel, num_layers) in reversed(list(enumerate(zip(block_out_channels, layers_per_block)))):
|
||||
up_block_list = []
|
||||
|
||||
if i < num_blocks - 1 and num_layers > 0:
|
||||
upsample_block = DCUpBlock2d(
|
||||
block_out_channels[i + 1],
|
||||
out_channel,
|
||||
interpolate=upsample_block_type == "interpolate",
|
||||
shortcut=True,
|
||||
)
|
||||
up_block_list.append(upsample_block)
|
||||
|
||||
for _ in range(num_layers):
|
||||
block = get_block(
|
||||
block_type[i],
|
||||
out_channel,
|
||||
out_channel,
|
||||
attention_head_dim=attention_head_dim,
|
||||
norm_type=norm_type[i],
|
||||
act_fn=act_fn[i],
|
||||
qkv_mutliscales=qkv_multiscales[i],
|
||||
)
|
||||
up_block_list.append(block)
|
||||
|
||||
up_blocks.insert(0, nn.Sequential(*up_block_list))
|
||||
|
||||
self.up_blocks = nn.ModuleList(up_blocks)
|
||||
|
||||
channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
|
||||
|
||||
self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
|
||||
self.conv_act = nn.ReLU()
|
||||
self.conv_out = None
|
||||
|
||||
if layers_per_block[0] > 0:
|
||||
self.conv_out = ops.Conv2d(channels, in_channels, 3, 1, 1)
|
||||
else:
|
||||
self.conv_out = DCUpBlock2d(
|
||||
channels, in_channels, interpolate=upsample_block_type == "interpolate", shortcut=False
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.in_shortcut:
|
||||
x = hidden_states.repeat_interleave(
|
||||
self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
|
||||
)
|
||||
hidden_states = self.conv_in(hidden_states) + x
|
||||
else:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
for up_block in reversed(self.up_blocks):
|
||||
hidden_states = up_block(hidden_states)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AutoencoderDC(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 2,
|
||||
latent_channels: int = 8,
|
||||
attention_head_dim: int = 32,
|
||||
encoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"],
|
||||
decoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"],
|
||||
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024),
|
||||
encoder_layers_per_block: Tuple[int] = (2, 2, 3, 3),
|
||||
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3),
|
||||
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)),
|
||||
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)),
|
||||
upsample_block_type: str = "interpolate",
|
||||
downsample_block_type: str = "Conv",
|
||||
decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
|
||||
decoder_act_fns: Union[str, Tuple[str]] = "silu",
|
||||
scaling_factor: float = 0.41407,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.encoder = Encoder(
|
||||
in_channels=in_channels,
|
||||
latent_channels=latent_channels,
|
||||
attention_head_dim=attention_head_dim,
|
||||
block_type=encoder_block_types,
|
||||
block_out_channels=encoder_block_out_channels,
|
||||
layers_per_block=encoder_layers_per_block,
|
||||
qkv_multiscales=encoder_qkv_multiscales,
|
||||
downsample_block_type=downsample_block_type,
|
||||
)
|
||||
|
||||
self.decoder = Decoder(
|
||||
in_channels=in_channels,
|
||||
latent_channels=latent_channels,
|
||||
attention_head_dim=attention_head_dim,
|
||||
block_type=decoder_block_types,
|
||||
block_out_channels=decoder_block_out_channels,
|
||||
layers_per_block=decoder_layers_per_block,
|
||||
qkv_multiscales=decoder_qkv_multiscales,
|
||||
norm_type=decoder_norm_types,
|
||||
act_fn=decoder_act_fns,
|
||||
upsample_block_type=upsample_block_type,
|
||||
)
|
||||
|
||||
self.scaling_factor = scaling_factor
|
||||
self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
|
||||
|
||||
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Internal encoding function."""
|
||||
encoded = self.encoder(x)
|
||||
return encoded * self.scaling_factor
|
||||
|
||||
def decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
# Scale the latents back
|
||||
z = z / self.scaling_factor
|
||||
decoded = self.decoder(z)
|
||||
return decoded
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
z = self.encode(x)
|
||||
return self.decode(z)
|
||||
|
109
comfy/ldm/ace/vae/music_dcae_pipeline.py
Normal file
109
comfy/ldm/ace/vae/music_dcae_pipeline.py
Normal file
@ -0,0 +1,109 @@
|
||||
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_dcae_pipeline.py
|
||||
import torch
|
||||
from .autoencoder_dc import AutoencoderDC
|
||||
import logging
|
||||
try:
|
||||
import torchaudio
|
||||
except:
|
||||
logging.warning("torchaudio missing, ACE model will be broken")
|
||||
|
||||
import torchvision.transforms as transforms
|
||||
from .music_vocoder import ADaMoSHiFiGANV1
|
||||
|
||||
|
||||
class MusicDCAE(torch.nn.Module):
|
||||
def __init__(self, source_sample_rate=None, dcae_config={}, vocoder_config={}):
|
||||
super(MusicDCAE, self).__init__()
|
||||
|
||||
self.dcae = AutoencoderDC(**dcae_config)
|
||||
self.vocoder = ADaMoSHiFiGANV1(**vocoder_config)
|
||||
|
||||
if source_sample_rate is None:
|
||||
self.source_sample_rate = 48000
|
||||
else:
|
||||
self.source_sample_rate = source_sample_rate
|
||||
|
||||
# self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.Normalize(0.5, 0.5),
|
||||
])
|
||||
self.min_mel_value = -11.0
|
||||
self.max_mel_value = 3.0
|
||||
self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
|
||||
self.mel_chunk_size = 1024
|
||||
self.time_dimention_multiple = 8
|
||||
self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
|
||||
self.scale_factor = 0.1786
|
||||
self.shift_factor = -1.9091
|
||||
|
||||
def load_audio(self, audio_path):
|
||||
audio, sr = torchaudio.load(audio_path)
|
||||
return audio, sr
|
||||
|
||||
def forward_mel(self, audios):
|
||||
mels = []
|
||||
for i in range(len(audios)):
|
||||
image = self.vocoder.mel_transform(audios[i])
|
||||
mels.append(image)
|
||||
mels = torch.stack(mels)
|
||||
return mels
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, audios, audio_lengths=None, sr=None):
|
||||
if audio_lengths is None:
|
||||
audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
|
||||
audio_lengths = audio_lengths.to(audios.device)
|
||||
|
||||
if sr is None:
|
||||
sr = self.source_sample_rate
|
||||
|
||||
if sr != 44100:
|
||||
audios = torchaudio.functional.resample(audios, sr, 44100)
|
||||
|
||||
max_audio_len = audios.shape[-1]
|
||||
if max_audio_len % (8 * 512) != 0:
|
||||
audios = torch.nn.functional.pad(audios, (0, 8 * 512 - max_audio_len % (8 * 512)))
|
||||
|
||||
mels = self.forward_mel(audios)
|
||||
mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
|
||||
mels = self.transform(mels)
|
||||
latents = []
|
||||
for mel in mels:
|
||||
latent = self.dcae.encoder(mel.unsqueeze(0))
|
||||
latents.append(latent)
|
||||
latents = torch.cat(latents, dim=0)
|
||||
# latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
|
||||
latents = (latents - self.shift_factor) * self.scale_factor
|
||||
return latents
|
||||
# return latents, latent_lengths
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, latents, audio_lengths=None, sr=None):
|
||||
latents = latents / self.scale_factor + self.shift_factor
|
||||
|
||||
pred_wavs = []
|
||||
|
||||
for latent in latents:
|
||||
mels = self.dcae.decoder(latent.unsqueeze(0))
|
||||
mels = mels * 0.5 + 0.5
|
||||
mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
|
||||
wav = self.vocoder.decode(mels[0]).squeeze(1)
|
||||
|
||||
if sr is not None:
|
||||
# resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
|
||||
wav = torchaudio.functional.resample(wav, 44100, sr)
|
||||
# wav = resampler(wav)
|
||||
else:
|
||||
sr = 44100
|
||||
pred_wavs.append(wav)
|
||||
|
||||
if audio_lengths is not None:
|
||||
pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
|
||||
return torch.stack(pred_wavs)
|
||||
# return sr, pred_wavs
|
||||
|
||||
def forward(self, audios, audio_lengths=None, sr=None):
|
||||
latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
|
||||
sr, pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
|
||||
return sr, pred_wavs, latents, latent_lengths
|
113
comfy/ldm/ace/vae/music_log_mel.py
Executable file
113
comfy/ldm/ace/vae/music_log_mel.py
Executable file
@ -0,0 +1,113 @@
|
||||
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_log_mel.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
import logging
|
||||
try:
|
||||
from torchaudio.transforms import MelScale
|
||||
except:
|
||||
logging.warning("torchaudio missing, ACE model will be broken")
|
||||
|
||||
import comfy.model_management
|
||||
|
||||
class LinearSpectrogram(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_fft=2048,
|
||||
win_length=2048,
|
||||
hop_length=512,
|
||||
center=False,
|
||||
mode="pow2_sqrt",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.n_fft = n_fft
|
||||
self.win_length = win_length
|
||||
self.hop_length = hop_length
|
||||
self.center = center
|
||||
self.mode = mode
|
||||
|
||||
self.register_buffer("window", torch.hann_window(win_length))
|
||||
|
||||
def forward(self, y: Tensor) -> Tensor:
|
||||
if y.ndim == 3:
|
||||
y = y.squeeze(1)
|
||||
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1),
|
||||
(
|
||||
(self.win_length - self.hop_length) // 2,
|
||||
(self.win_length - self.hop_length + 1) // 2,
|
||||
),
|
||||
mode="reflect",
|
||||
).squeeze(1)
|
||||
dtype = y.dtype
|
||||
spec = torch.stft(
|
||||
y.float(),
|
||||
self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
window=comfy.model_management.cast_to(self.window, dtype=torch.float32, device=y.device),
|
||||
center=self.center,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=True,
|
||||
)
|
||||
spec = torch.view_as_real(spec)
|
||||
|
||||
if self.mode == "pow2_sqrt":
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||
spec = spec.to(dtype)
|
||||
return spec
|
||||
|
||||
|
||||
class LogMelSpectrogram(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate=44100,
|
||||
n_fft=2048,
|
||||
win_length=2048,
|
||||
hop_length=512,
|
||||
n_mels=128,
|
||||
center=False,
|
||||
f_min=0.0,
|
||||
f_max=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_rate = sample_rate
|
||||
self.n_fft = n_fft
|
||||
self.win_length = win_length
|
||||
self.hop_length = hop_length
|
||||
self.center = center
|
||||
self.n_mels = n_mels
|
||||
self.f_min = f_min
|
||||
self.f_max = f_max or sample_rate // 2
|
||||
|
||||
self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
|
||||
self.mel_scale = MelScale(
|
||||
self.n_mels,
|
||||
self.sample_rate,
|
||||
self.f_min,
|
||||
self.f_max,
|
||||
self.n_fft // 2 + 1,
|
||||
"slaney",
|
||||
"slaney",
|
||||
)
|
||||
|
||||
def compress(self, x: Tensor) -> Tensor:
|
||||
return torch.log(torch.clamp(x, min=1e-5))
|
||||
|
||||
def decompress(self, x: Tensor) -> Tensor:
|
||||
return torch.exp(x)
|
||||
|
||||
def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
|
||||
linear = self.spectrogram(x)
|
||||
x = self.mel_scale(linear)
|
||||
x = self.compress(x)
|
||||
# print(x.shape)
|
||||
if return_linear:
|
||||
return x, self.compress(linear)
|
||||
|
||||
return x
|
538
comfy/ldm/ace/vae/music_vocoder.py
Executable file
538
comfy/ldm/ace/vae/music_vocoder.py
Executable file
@ -0,0 +1,538 @@
|
||||
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_vocoder.py
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from functools import partial
|
||||
from math import prod
|
||||
from typing import Callable, Tuple, List
|
||||
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
|
||||
|
||||
from .music_log_mel import LogMelSpectrogram
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
def drop_path(
|
||||
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
|
||||
):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
|
||||
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||
'survival rate' as the argument.
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
if drop_prob == 0.0 or not training:
|
||||
return x
|
||||
keep_prob = 1 - drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (
|
||||
x.ndim - 1
|
||||
) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
||||
if keep_prob > 0.0 and scale_by_keep:
|
||||
random_tensor.div_(keep_prob)
|
||||
return x * random_tensor
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
|
||||
|
||||
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
self.scale_by_keep = scale_by_keep
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"drop_prob={round(self.drop_prob,3):0.3f}"
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
||||
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
|
||||
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
|
||||
with shape (batch_size, channels, height, width).
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
||||
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
||||
self.eps = eps
|
||||
self.data_format = data_format
|
||||
if self.data_format not in ["channels_last", "channels_first"]:
|
||||
raise NotImplementedError
|
||||
self.normalized_shape = (normalized_shape,)
|
||||
|
||||
def forward(self, x):
|
||||
if self.data_format == "channels_last":
|
||||
return F.layer_norm(
|
||||
x, self.normalized_shape, comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device), comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device), self.eps
|
||||
)
|
||||
elif self.data_format == "channels_first":
|
||||
u = x.mean(1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.eps)
|
||||
x = comfy.model_management.cast_to(self.weight[:, None], dtype=x.dtype, device=x.device) * x + comfy.model_management.cast_to(self.bias[:, None], dtype=x.dtype, device=x.device)
|
||||
return x
|
||||
|
||||
|
||||
class ConvNeXtBlock(nn.Module):
|
||||
r"""ConvNeXt Block. There are two equivalent implementations:
|
||||
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
||||
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
||||
We use (2) as we find it slightly faster in PyTorch
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
drop_path (float): Stochastic depth rate. Default: 0.0
|
||||
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
|
||||
kernel_size (int): Kernel size for depthwise conv. Default: 7.
|
||||
dilation (int): Dilation for depthwise conv. Default: 1.
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
drop_path: float = 0.0,
|
||||
layer_scale_init_value: float = 1e-6,
|
||||
mlp_ratio: float = 4.0,
|
||||
kernel_size: int = 7,
|
||||
dilation: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dwconv = ops.Conv1d(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size=kernel_size,
|
||||
padding=int(dilation * (kernel_size - 1) / 2),
|
||||
groups=dim,
|
||||
) # depthwise conv
|
||||
self.norm = LayerNorm(dim, eps=1e-6)
|
||||
self.pwconv1 = ops.Linear(
|
||||
dim, int(mlp_ratio * dim)
|
||||
) # pointwise/1x1 convs, implemented with linear layers
|
||||
self.act = nn.GELU()
|
||||
self.pwconv2 = ops.Linear(int(mlp_ratio * dim), dim)
|
||||
self.gamma = (
|
||||
nn.Parameter(torch.empty((dim)), requires_grad=False)
|
||||
if layer_scale_init_value > 0
|
||||
else None
|
||||
)
|
||||
self.drop_path = DropPath(
|
||||
drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
def forward(self, x, apply_residual: bool = True):
|
||||
input = x
|
||||
|
||||
x = self.dwconv(x)
|
||||
x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
|
||||
x = self.norm(x)
|
||||
x = self.pwconv1(x)
|
||||
x = self.act(x)
|
||||
x = self.pwconv2(x)
|
||||
|
||||
if self.gamma is not None:
|
||||
x = comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) * x
|
||||
|
||||
x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
|
||||
x = self.drop_path(x)
|
||||
|
||||
if apply_residual:
|
||||
x = input + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ParallelConvNeXtBlock(nn.Module):
|
||||
def __init__(self, kernel_sizes: List[int], *args, **kwargs):
|
||||
super().__init__()
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
ConvNeXtBlock(kernel_size=kernel_size, *args, **kwargs)
|
||||
for kernel_size in kernel_sizes
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.stack(
|
||||
[block(x, apply_residual=False) for block in self.blocks] + [x],
|
||||
dim=1,
|
||||
).sum(dim=1)
|
||||
|
||||
|
||||
class ConvNeXtEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_channels=3,
|
||||
depths=[3, 3, 9, 3],
|
||||
dims=[96, 192, 384, 768],
|
||||
drop_path_rate=0.0,
|
||||
layer_scale_init_value=1e-6,
|
||||
kernel_sizes: Tuple[int] = (7,),
|
||||
):
|
||||
super().__init__()
|
||||
assert len(depths) == len(dims)
|
||||
|
||||
self.channel_layers = nn.ModuleList()
|
||||
stem = nn.Sequential(
|
||||
ops.Conv1d(
|
||||
input_channels,
|
||||
dims[0],
|
||||
kernel_size=7,
|
||||
padding=3,
|
||||
padding_mode="replicate",
|
||||
),
|
||||
LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
|
||||
)
|
||||
self.channel_layers.append(stem)
|
||||
|
||||
for i in range(len(depths) - 1):
|
||||
mid_layer = nn.Sequential(
|
||||
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
|
||||
ops.Conv1d(dims[i], dims[i + 1], kernel_size=1),
|
||||
)
|
||||
self.channel_layers.append(mid_layer)
|
||||
|
||||
block_fn = (
|
||||
partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
|
||||
if len(kernel_sizes) == 1
|
||||
else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
|
||||
)
|
||||
|
||||
self.stages = nn.ModuleList()
|
||||
drop_path_rates = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
||||
]
|
||||
|
||||
cur = 0
|
||||
for i in range(len(depths)):
|
||||
stage = nn.Sequential(
|
||||
*[
|
||||
block_fn(
|
||||
dim=dims[i],
|
||||
drop_path=drop_path_rates[cur + j],
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
)
|
||||
for j in range(depths[i])
|
||||
]
|
||||
)
|
||||
self.stages.append(stage)
|
||||
cur += depths[i]
|
||||
|
||||
self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
for channel_layer, stage in zip(self.channel_layers, self.stages):
|
||||
x = channel_layer(x)
|
||||
x = stage(x)
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return (kernel_size * dilation - dilation) // 2
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super().__init__()
|
||||
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.silu(x)
|
||||
xt = c1(xt)
|
||||
xt = F.silu(xt)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for conv in self.convs1:
|
||||
remove_weight_norm(conv)
|
||||
for conv in self.convs2:
|
||||
remove_weight_norm(conv)
|
||||
|
||||
|
||||
class HiFiGANGenerator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
hop_length: int = 512,
|
||||
upsample_rates: Tuple[int] = (8, 8, 2, 2, 2),
|
||||
upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2),
|
||||
resblock_kernel_sizes: Tuple[int] = (3, 7, 11),
|
||||
resblock_dilation_sizes: Tuple[Tuple[int]] = (
|
||||
(1, 3, 5), (1, 3, 5), (1, 3, 5)),
|
||||
num_mels: int = 128,
|
||||
upsample_initial_channel: int = 512,
|
||||
use_template: bool = True,
|
||||
pre_conv_kernel_size: int = 7,
|
||||
post_conv_kernel_size: int = 7,
|
||||
post_activation: Callable = partial(nn.SiLU, inplace=True),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert (
|
||||
prod(upsample_rates) == hop_length
|
||||
), f"hop_length must be {prod(upsample_rates)}"
|
||||
|
||||
self.conv_pre = torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
num_mels,
|
||||
upsample_initial_channel,
|
||||
pre_conv_kernel_size,
|
||||
1,
|
||||
padding=get_padding(pre_conv_kernel_size),
|
||||
)
|
||||
)
|
||||
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
|
||||
self.noise_convs = nn.ModuleList()
|
||||
self.use_template = use_template
|
||||
self.ups = nn.ModuleList()
|
||||
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
c_cur = upsample_initial_channel // (2 ** (i + 1))
|
||||
self.ups.append(
|
||||
torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.ConvTranspose1d(
|
||||
upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2 ** (i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if not use_template:
|
||||
continue
|
||||
|
||||
if i + 1 < len(upsample_rates):
|
||||
stride_f0 = np.prod(upsample_rates[i + 1:])
|
||||
self.noise_convs.append(
|
||||
ops.Conv1d(
|
||||
1,
|
||||
c_cur,
|
||||
kernel_size=stride_f0 * 2,
|
||||
stride=stride_f0,
|
||||
padding=stride_f0 // 2,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.noise_convs.append(ops.Conv1d(1, c_cur, kernel_size=1))
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
|
||||
self.resblocks.append(ResBlock1(ch, k, d))
|
||||
|
||||
self.activation_post = post_activation()
|
||||
self.conv_post = torch.nn.utils.parametrizations.weight_norm(
|
||||
ops.Conv1d(
|
||||
ch,
|
||||
1,
|
||||
post_conv_kernel_size,
|
||||
1,
|
||||
padding=get_padding(post_conv_kernel_size),
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x, template=None):
|
||||
x = self.conv_pre(x)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.silu(x, inplace=True)
|
||||
x = self.ups[i](x)
|
||||
|
||||
if self.use_template:
|
||||
x = x + self.noise_convs[i](template)
|
||||
|
||||
xs = None
|
||||
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
|
||||
x = xs / self.num_kernels
|
||||
|
||||
x = self.activation_post(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for up in self.ups:
|
||||
remove_weight_norm(up)
|
||||
for block in self.resblocks:
|
||||
block.remove_weight_norm()
|
||||
remove_weight_norm(self.conv_pre)
|
||||
remove_weight_norm(self.conv_post)
|
||||
|
||||
|
||||
class ADaMoSHiFiGANV1(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_channels: int = 128,
|
||||
depths: List[int] = [3, 3, 9, 3],
|
||||
dims: List[int] = [128, 256, 384, 512],
|
||||
drop_path_rate: float = 0.0,
|
||||
kernel_sizes: Tuple[int] = (7,),
|
||||
upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2),
|
||||
upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4),
|
||||
resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13),
|
||||
resblock_dilation_sizes: Tuple[Tuple[int]] = (
|
||||
(1, 3, 5), (1, 3, 5), (1, 3, 5), (1, 3, 5)),
|
||||
num_mels: int = 512,
|
||||
upsample_initial_channel: int = 1024,
|
||||
use_template: bool = False,
|
||||
pre_conv_kernel_size: int = 13,
|
||||
post_conv_kernel_size: int = 13,
|
||||
sampling_rate: int = 44100,
|
||||
n_fft: int = 2048,
|
||||
win_length: int = 2048,
|
||||
hop_length: int = 512,
|
||||
f_min: int = 40,
|
||||
f_max: int = 16000,
|
||||
n_mels: int = 128,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.backbone = ConvNeXtEncoder(
|
||||
input_channels=input_channels,
|
||||
depths=depths,
|
||||
dims=dims,
|
||||
drop_path_rate=drop_path_rate,
|
||||
kernel_sizes=kernel_sizes,
|
||||
)
|
||||
|
||||
self.head = HiFiGANGenerator(
|
||||
hop_length=hop_length,
|
||||
upsample_rates=upsample_rates,
|
||||
upsample_kernel_sizes=upsample_kernel_sizes,
|
||||
resblock_kernel_sizes=resblock_kernel_sizes,
|
||||
resblock_dilation_sizes=resblock_dilation_sizes,
|
||||
num_mels=num_mels,
|
||||
upsample_initial_channel=upsample_initial_channel,
|
||||
use_template=use_template,
|
||||
pre_conv_kernel_size=pre_conv_kernel_size,
|
||||
post_conv_kernel_size=post_conv_kernel_size,
|
||||
)
|
||||
self.sampling_rate = sampling_rate
|
||||
self.mel_transform = LogMelSpectrogram(
|
||||
sample_rate=sampling_rate,
|
||||
n_fft=n_fft,
|
||||
win_length=win_length,
|
||||
hop_length=hop_length,
|
||||
f_min=f_min,
|
||||
f_max=f_max,
|
||||
n_mels=n_mels,
|
||||
)
|
||||
self.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, mel):
|
||||
y = self.backbone(mel)
|
||||
y = self.head(y)
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, x):
|
||||
return self.mel_transform(x)
|
||||
|
||||
def forward(self, mel):
|
||||
y = self.backbone(mel)
|
||||
y = self.head(y)
|
||||
return y
|
@ -75,16 +75,10 @@ class SnakeBeta(nn.Module):
|
||||
return x
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
try:
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
|
||||
except:
|
||||
return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
|
||||
|
||||
def WNConvTranspose1d(*args, **kwargs):
|
||||
try:
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
|
||||
except:
|
||||
return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
|
||||
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
|
||||
|
||||
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
|
||||
if activation == "elu":
|
||||
|
183
comfy/ldm/chroma/layers.py
Normal file
183
comfy/ldm/chroma/layers.py
Normal file
@ -0,0 +1,183 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from comfy.ldm.flux.math import attention
|
||||
from comfy.ldm.flux.layers import (
|
||||
MLPEmbedder,
|
||||
RMSNorm,
|
||||
QKNorm,
|
||||
SelfAttention,
|
||||
ModulationOut,
|
||||
)
|
||||
|
||||
|
||||
|
||||
class ChromaModulationOut(ModulationOut):
|
||||
@classmethod
|
||||
def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
|
||||
return cls(
|
||||
shift=tensor[:, offset : offset + 1, :],
|
||||
scale=tensor[:, offset + 1 : offset + 2, :],
|
||||
gate=tensor[:, offset + 2 : offset + 3, :],
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
class Approximator(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers = 5, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
||||
self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
||||
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
# Get the device of the module (assumes all parameters are on the same device)
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = self.in_proj(x)
|
||||
|
||||
for layer, norms in zip(self.layers, self.norms):
|
||||
x = x + layer(norms(x))
|
||||
|
||||
x = self.out_proj(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2),
|
||||
pe=pe, mask=attn_mask)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||
|
||||
return img, txt
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with parallel linear layers as described in
|
||||
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float = None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
head_dim = hidden_size // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
# qkv and mlp_in
|
||||
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
|
||||
# proj and mlp_out
|
||||
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
|
||||
mod = vec
|
||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x += mod.gate * output
|
||||
if x.dtype == torch.float16:
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
||||
shift, scale = vec
|
||||
shift = shift.squeeze(1)
|
||||
scale = scale.squeeze(1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = self.linear(x)
|
||||
return x
|
271
comfy/ldm/chroma/model.py
Normal file
271
comfy/ldm/chroma/model.py
Normal file
@ -0,0 +1,271 @@
|
||||
#Original code can be found on: https://github.com/black-forest-labs/flux
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from comfy.ldm.flux.layers import (
|
||||
EmbedND,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
LastLayer,
|
||||
SingleStreamBlock,
|
||||
Approximator,
|
||||
ChromaModulationOut,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChromaParams:
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
mlp_ratio: float
|
||||
num_heads: int
|
||||
depth: int
|
||||
depth_single_blocks: int
|
||||
axes_dim: list
|
||||
theta: int
|
||||
patch_size: int
|
||||
qkv_bias: bool
|
||||
in_dim: int
|
||||
out_dim: int
|
||||
hidden_dim: int
|
||||
n_layers: int
|
||||
|
||||
|
||||
|
||||
|
||||
class Chroma(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
params = ChromaParams(**kwargs)
|
||||
self.params = params
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = params.out_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
)
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.in_dim = params.in_dim
|
||||
self.out_dim = params.out_dim
|
||||
self.hidden_dim = params.hidden_dim
|
||||
self.n_layers = params.n_layers
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
|
||||
# set as nn identity for now, will overwrite it later.
|
||||
self.distilled_guidance_layer = Approximator(
|
||||
in_dim=self.in_dim,
|
||||
hidden_dim=self.hidden_dim,
|
||||
out_dim=self.out_dim,
|
||||
n_layers=self.n_layers,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.skip_mmdit = []
|
||||
self.skip_dit = []
|
||||
self.lite = False
|
||||
|
||||
def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0):
|
||||
# This function slices up the modulations tensor which has the following layout:
|
||||
# single : num_single_blocks * 3 elements
|
||||
# double_img : num_double_blocks * 6 elements
|
||||
# double_txt : num_double_blocks * 6 elements
|
||||
# final : 2 elements
|
||||
if block_type == "final":
|
||||
return (tensor[:, -2:-1, :], tensor[:, -1:, :])
|
||||
single_block_count = self.params.depth_single_blocks
|
||||
double_block_count = self.params.depth
|
||||
offset = 3 * idx
|
||||
if block_type == "single":
|
||||
return ChromaModulationOut.from_offset(tensor, offset)
|
||||
# Double block modulations are 6 elements so we double 3 * idx.
|
||||
offset *= 2
|
||||
if block_type in {"double_img", "double_txt"}:
|
||||
# Advance past the single block modulations.
|
||||
offset += 3 * single_block_count
|
||||
if block_type == "double_txt":
|
||||
# Advance past the double block img modulations.
|
||||
offset += 6 * double_block_count
|
||||
return (
|
||||
ChromaModulationOut.from_offset(tensor, offset),
|
||||
ChromaModulationOut.from_offset(tensor, offset + 3),
|
||||
)
|
||||
raise ValueError("Bad block_type")
|
||||
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
timesteps: Tensor,
|
||||
guidance: Tensor = None,
|
||||
control = None,
|
||||
transformer_options={},
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
|
||||
# distilled vector guidance
|
||||
mod_index_length = 344
|
||||
distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(img.device, img.dtype)
|
||||
# guidance = guidance *
|
||||
distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype)
|
||||
|
||||
# get all modulation index
|
||||
modulation_index = timestep_embedding(torch.arange(mod_index_length), 32).to(img.device, img.dtype)
|
||||
# we need to broadcast the modulation index here so each batch has all of the index
|
||||
modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
|
||||
# and we need to broadcast timestep and guidance along too
|
||||
timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1).to(img.dtype).to(img.device, img.dtype)
|
||||
# then and only then we could concatenate it together
|
||||
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1).to(img.device, img.dtype)
|
||||
|
||||
mod_vectors = self.distilled_guidance_layer(input_vec)
|
||||
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
if i not in self.skip_mmdit:
|
||||
double_mod = (
|
||||
self.get_modulations(mod_vectors, "double_img", idx=i),
|
||||
self.get_modulations(mod_vectors, "double_txt", idx=i),
|
||||
)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"],
|
||||
txt=args["txt"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": double_mod,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img,
|
||||
txt=txt,
|
||||
vec=double_mod,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
img += add
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
if i not in self.skip_dit:
|
||||
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"],
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
"vec": single_mod,
|
||||
"pe": pe,
|
||||
"attn_mask": attn_mask},
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img[:, txt.shape[1] :, ...] += add
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
final_mod = self.get_modulations(mod_vectors, "final")
|
||||
img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = 2
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
@ -23,7 +23,6 @@ from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
from torch import nn
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
@ -37,11 +36,11 @@ def apply_rotary_pos_emb(
|
||||
return t_out
|
||||
|
||||
|
||||
def get_normalization(name: str, channels: int, weight_args={}):
|
||||
def get_normalization(name: str, channels: int, weight_args={}, operations=None):
|
||||
if name == "I":
|
||||
return nn.Identity()
|
||||
elif name == "R":
|
||||
return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
|
||||
return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
|
||||
else:
|
||||
raise ValueError(f"Normalization {name} not found")
|
||||
|
||||
@ -120,15 +119,15 @@ class Attention(nn.Module):
|
||||
|
||||
self.to_q = nn.Sequential(
|
||||
operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
|
||||
get_normalization(qkv_norm[0], norm_dim),
|
||||
get_normalization(qkv_norm[0], norm_dim, weight_args=weight_args, operations=operations),
|
||||
)
|
||||
self.to_k = nn.Sequential(
|
||||
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
|
||||
get_normalization(qkv_norm[1], norm_dim),
|
||||
get_normalization(qkv_norm[1], norm_dim, weight_args=weight_args, operations=operations),
|
||||
)
|
||||
self.to_v = nn.Sequential(
|
||||
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
|
||||
get_normalization(qkv_norm[2], norm_dim),
|
||||
get_normalization(qkv_norm[2], norm_dim, weight_args=weight_args, operations=operations),
|
||||
)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
|
@ -27,8 +27,6 @@ from torchvision import transforms
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
|
||||
|
||||
from .blocks import (
|
||||
FinalLayer,
|
||||
GeneralDITTransformerBlock,
|
||||
@ -195,7 +193,7 @@ class GeneralDIT(nn.Module):
|
||||
|
||||
if self.affline_emb_norm:
|
||||
logging.debug("Building affine embedding normalization layer")
|
||||
self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6)
|
||||
self.affline_norm = operations.RMSNorm(model_channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
|
||||
else:
|
||||
self.affline_norm = nn.Identity()
|
||||
|
||||
|
@ -13,7 +13,6 @@ from comfy.ldm.modules.attention import optimized_attention
|
||||
from .layers import (
|
||||
FeedForward,
|
||||
PatchEmbed,
|
||||
RMSNorm,
|
||||
TimestepEmbedder,
|
||||
)
|
||||
|
||||
@ -90,10 +89,10 @@ class AsymmetricAttention(nn.Module):
|
||||
|
||||
# Query and key normalization for stability.
|
||||
assert qk_norm
|
||||
self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
|
||||
self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
|
||||
self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
|
||||
self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
|
||||
self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
|
||||
|
||||
# Output layers. y features go back down from dim_x -> dim_y.
|
||||
self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)
|
||||
|
@ -151,14 +151,3 @@ class PatchEmbed(nn.Module):
|
||||
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(self, x):
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
|
||||
|
@ -699,10 +699,13 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states_llama3=None,
|
||||
image_cond=None,
|
||||
control = None,
|
||||
transformer_options = {},
|
||||
) -> torch.Tensor:
|
||||
bs, c, h, w = x.shape
|
||||
if image_cond is not None:
|
||||
x = torch.cat([x, image_cond], dim=-1)
|
||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
timesteps = t
|
||||
pooled_embeds = y
|
||||
|
@ -3,7 +3,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import comfy.ops
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed
|
||||
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
||||
from torch.utils import checkpoint
|
||||
|
||||
@ -51,7 +51,7 @@ class HunYuanDiTBlock(nn.Module):
|
||||
if norm_type == "layer":
|
||||
norm_layer = operations.LayerNorm
|
||||
elif norm_type == "rms":
|
||||
norm_layer = RMSNorm
|
||||
norm_layer = operations.RMSNorm
|
||||
else:
|
||||
raise ValueError(f"Unknown norm_type: {norm_type}")
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import comfy.ldm.modules.attention
|
||||
from comfy.ldm.genmo.joint_model.layers import RMSNorm
|
||||
import comfy.ldm.common_dit
|
||||
from einops import rearrange
|
||||
import math
|
||||
@ -262,8 +261,8 @@ class CrossAttention(nn.Module):
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
|
||||
|
||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
@ -8,7 +8,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
|
||||
@ -64,8 +64,8 @@ class JointAttention(nn.Module):
|
||||
)
|
||||
|
||||
if qk_norm:
|
||||
self.q_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
|
||||
self.k_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
|
||||
self.q_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.k_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
else:
|
||||
self.q_norm = self.k_norm = nn.Identity()
|
||||
|
||||
@ -242,11 +242,11 @@ class JointTransformerBlock(nn.Module):
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
self.layer_id = layer_id
|
||||
self.attention_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
||||
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
||||
self.attention_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.ffn_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
self.attention_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
||||
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
||||
self.attention_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.ffn_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
self.modulation = modulation
|
||||
if modulation:
|
||||
@ -431,7 +431,7 @@ class NextDiT(nn.Module):
|
||||
|
||||
self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
|
||||
self.cap_embedder = nn.Sequential(
|
||||
RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, **operation_settings),
|
||||
operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
||||
operation_settings.get("operations").Linear(
|
||||
cap_feat_dim,
|
||||
dim,
|
||||
@ -457,7 +457,7 @@ class NextDiT(nn.Module):
|
||||
for layer_id in range(n_layers)
|
||||
]
|
||||
)
|
||||
self.norm_final = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
||||
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
|
||||
|
||||
assert (dim // n_heads) == sum(axes_dims)
|
||||
|
@ -9,7 +9,6 @@ from einops import repeat
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_management
|
||||
|
||||
@ -49,8 +48,8 @@ class WanSelfAttention(nn.Module):
|
||||
self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.norm_q = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
self.norm_k = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, x, freqs):
|
||||
r"""
|
||||
@ -114,7 +113,7 @@ class WanI2VCrossAttention(WanSelfAttention):
|
||||
self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
||||
self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, x, context, context_img_len):
|
||||
r"""
|
||||
@ -631,6 +630,7 @@ class VaceWanModel(WanModel):
|
||||
if ii is not None:
|
||||
c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||
x += c_skip * vace_strength
|
||||
del c_skip
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
|
@ -279,6 +279,13 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
key_map["diffusion_model.{}".format(key_lora)] = k # Old loras
|
||||
|
||||
if isinstance(model, comfy.model_base.HiDream):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model."):
|
||||
if k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
|
@ -38,6 +38,8 @@ import comfy.ldm.lumina.model
|
||||
import comfy.ldm.wan.model
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
import comfy.ldm.ace.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@ -786,8 +788,8 @@ class PixArt(BaseModel):
|
||||
return out
|
||||
|
||||
class Flux(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.flux.model.Flux):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=unet_model)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
try:
|
||||
@ -1104,4 +1106,38 @@ class HiDream(BaseModel):
|
||||
conditioning_llama3 = kwargs.get("conditioning_llama3", None)
|
||||
if conditioning_llama3 is not None:
|
||||
out['encoder_hidden_states_llama3'] = comfy.conds.CONDRegular(conditioning_llama3)
|
||||
image_cond = kwargs.get("concat_latent_image", None)
|
||||
if image_cond is not None:
|
||||
out['image_cond'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_cond))
|
||||
return out
|
||||
|
||||
class Chroma(Flux):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
|
||||
guidance = kwargs.get("guidance", 0)
|
||||
if guidance is not None:
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
return out
|
||||
|
||||
class ACEStep(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.model.ACEStepTransformer2DModel)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
noise = kwargs.get("noise", None)
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
|
||||
if cross_attn is not None:
|
||||
out['lyric_token_idx'] = comfy.conds.CONDRegular(conditioning_lyrics)
|
||||
out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
|
||||
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
|
||||
return out
|
||||
|
@ -164,7 +164,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
if in_key in state_dict_keys:
|
||||
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["vec_in_dim"] = 768
|
||||
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
|
||||
if vec_in_key in state_dict_keys:
|
||||
dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
|
||||
dit_config["context_in_dim"] = 4096
|
||||
dit_config["hidden_size"] = 3072
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
@ -174,7 +176,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["axes_dim"] = [16, 56, 56]
|
||||
dit_config["theta"] = 10000
|
||||
dit_config["qkv_bias"] = True
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
|
||||
dit_config["image_model"] = "chroma"
|
||||
dit_config["in_channels"] = 64
|
||||
dit_config["out_channels"] = 64
|
||||
dit_config["in_dim"] = 64
|
||||
dit_config["out_dim"] = 3072
|
||||
dit_config["hidden_dim"] = 5120
|
||||
dit_config["n_layers"] = 5
|
||||
else:
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
return dit_config
|
||||
|
||||
if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
|
||||
@ -211,10 +222,39 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "ltxv"
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
|
||||
dit_config["attention_head_dim"] = shape[0] // 32
|
||||
dit_config["cross_attention_dim"] = shape[1]
|
||||
if metadata is not None and "config" in metadata:
|
||||
dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
|
||||
return dit_config
|
||||
|
||||
if '{}genre_embedder.weight'.format(key_prefix) in state_dict_keys: #ACE-Step model
|
||||
dit_config = {}
|
||||
dit_config["audio_model"] = "ace"
|
||||
dit_config["attention_head_dim"] = 128
|
||||
dit_config["in_channels"] = 8
|
||||
dit_config["inner_dim"] = 2560
|
||||
dit_config["max_height"] = 16
|
||||
dit_config["max_position"] = 32768
|
||||
dit_config["max_width"] = 32768
|
||||
dit_config["mlp_ratio"] = 2.5
|
||||
dit_config["num_attention_heads"] = 20
|
||||
dit_config["num_layers"] = 24
|
||||
dit_config["out_channels"] = 8
|
||||
dit_config["patch_size"] = [16, 1]
|
||||
dit_config["rope_theta"] = 1000000.0
|
||||
dit_config["speaker_embedding_dim"] = 512
|
||||
dit_config["text_embedding_dim"] = 768
|
||||
|
||||
dit_config["ssl_encoder_depths"] = [8, 8]
|
||||
dit_config["ssl_latent_dims"] = [1024, 768]
|
||||
dit_config["ssl_names"] = ["mert", "m-hubert"]
|
||||
dit_config["lyric_encoder_vocab_size"] = 6693
|
||||
dit_config["lyric_hidden_size"] = 1024
|
||||
return dit_config
|
||||
|
||||
if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
|
||||
patch_size = 2
|
||||
dit_config = {}
|
||||
|
@ -967,15 +967,61 @@ def force_channels_last():
|
||||
#TODO
|
||||
return False
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
|
||||
|
||||
STREAMS = {}
|
||||
NUM_STREAMS = 1
|
||||
if args.async_offload:
|
||||
NUM_STREAMS = 2
|
||||
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
|
||||
|
||||
stream_counters = {}
|
||||
def get_offload_stream(device):
|
||||
stream_counter = stream_counters.get(device, 0)
|
||||
if NUM_STREAMS <= 1:
|
||||
return None
|
||||
|
||||
if device in STREAMS:
|
||||
ss = STREAMS[device]
|
||||
s = ss[stream_counter]
|
||||
stream_counter = (stream_counter + 1) % len(ss)
|
||||
if is_device_cuda(device):
|
||||
ss[stream_counter].wait_stream(torch.cuda.current_stream())
|
||||
stream_counters[device] = stream_counter
|
||||
return s
|
||||
elif is_device_cuda(device):
|
||||
ss = []
|
||||
for k in range(NUM_STREAMS):
|
||||
ss.append(torch.cuda.Stream(device=device, priority=0))
|
||||
STREAMS[device] = ss
|
||||
s = ss[stream_counter]
|
||||
stream_counter = (stream_counter + 1) % len(ss)
|
||||
stream_counters[device] = stream_counter
|
||||
return s
|
||||
return None
|
||||
|
||||
def sync_stream(device, stream):
|
||||
if stream is None:
|
||||
return
|
||||
if is_device_cuda(device):
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
if stream is not None:
|
||||
with stream:
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
if stream is not None:
|
||||
with stream:
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
else:
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
return r
|
||||
|
||||
def cast_to_device(tensor, device, dtype, copy=False):
|
||||
|
@ -111,13 +111,14 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.linear_start = linear_start
|
||||
self.linear_end = linear_end
|
||||
self.zsnr = zsnr
|
||||
|
||||
# self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
|
||||
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
|
||||
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||
|
||||
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||
if zsnr:
|
||||
if self.zsnr:
|
||||
sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
|
||||
|
||||
self.set_sigmas(sigmas)
|
||||
|
28
comfy/ops.py
28
comfy/ops.py
@ -22,6 +22,7 @@ import comfy.model_management
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy.float
|
||||
import comfy.rmsnorm
|
||||
import contextlib
|
||||
|
||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
|
||||
@ -37,20 +38,31 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
if device is None:
|
||||
device = input.device
|
||||
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
if offload_stream is not None:
|
||||
wf_context = offload_stream
|
||||
else:
|
||||
wf_context = contextlib.nullcontext()
|
||||
|
||||
bias = None
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
if s.bias is not None:
|
||||
has_function = len(s.bias_function) > 0
|
||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
|
||||
|
||||
if has_function:
|
||||
for f in s.bias_function:
|
||||
bias = f(bias)
|
||||
with wf_context:
|
||||
for f in s.bias_function:
|
||||
bias = f(bias)
|
||||
|
||||
has_function = len(s.weight_function) > 0
|
||||
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
|
||||
if has_function:
|
||||
for f in s.weight_function:
|
||||
weight = f(weight)
|
||||
with wf_context:
|
||||
for f in s.weight_function:
|
||||
weight = f(weight)
|
||||
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
return weight, bias
|
||||
|
||||
class CastWeightBiasOp:
|
||||
@ -296,10 +308,10 @@ def fp8_linear(self, input):
|
||||
if scale_input is None:
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
input = input.reshape(-1, input_shape[2]).to(dtype)
|
||||
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
|
||||
else:
|
||||
scale_input = scale_input.to(input.device)
|
||||
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)
|
||||
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
|
||||
|
||||
if bias is not None:
|
||||
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
||||
|
@ -903,7 +903,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
||||
"gradient_estimation", "er_sde", "seeds_2", "seeds_3"]
|
||||
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3"]
|
||||
|
||||
class KSAMPLER(Sampler):
|
||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||
|
73
comfy/sd.py
73
comfy/sd.py
@ -15,6 +15,7 @@ import comfy.ldm.lightricks.vae.causal_video_autoencoder
|
||||
import comfy.ldm.cosmos.vae
|
||||
import comfy.ldm.wan.vae
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import yaml
|
||||
import math
|
||||
|
||||
@ -42,6 +43,7 @@ import comfy.text_encoders.cosmos
|
||||
import comfy.text_encoders.lumina2
|
||||
import comfy.text_encoders.wan
|
||||
import comfy.text_encoders.hidream
|
||||
import comfy.text_encoders.ace
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -120,6 +122,7 @@ class CLIP:
|
||||
self.layer_idx = None
|
||||
self.use_clip_schedule = False
|
||||
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
|
||||
self.tokenizer_options = {}
|
||||
|
||||
def clone(self):
|
||||
n = CLIP(no_init=True)
|
||||
@ -127,6 +130,7 @@ class CLIP:
|
||||
n.cond_stage_model = self.cond_stage_model
|
||||
n.tokenizer = self.tokenizer
|
||||
n.layer_idx = self.layer_idx
|
||||
n.tokenizer_options = self.tokenizer_options.copy()
|
||||
n.use_clip_schedule = self.use_clip_schedule
|
||||
n.apply_hooks_to_conds = self.apply_hooks_to_conds
|
||||
return n
|
||||
@ -134,10 +138,18 @@ class CLIP:
|
||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||
return self.patcher.add_patches(patches, strength_patch, strength_model)
|
||||
|
||||
def set_tokenizer_option(self, option_name, value):
|
||||
self.tokenizer_options[option_name] = value
|
||||
|
||||
def clip_layer(self, layer_idx):
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def tokenize(self, text, return_word_ids=False, **kwargs):
|
||||
tokenizer_options = kwargs.get("tokenizer_options", {})
|
||||
if len(self.tokenizer_options) > 0:
|
||||
tokenizer_options = {**self.tokenizer_options, **tokenizer_options}
|
||||
if len(tokenizer_options) > 0:
|
||||
kwargs["tokenizer_options"] = tokenizer_options
|
||||
return self.tokenizer.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
|
||||
def add_hooks_to_dict(self, pooled_dict: dict[str]):
|
||||
@ -270,6 +282,7 @@ class VAE:
|
||||
|
||||
self.downscale_index_formula = None
|
||||
self.upscale_index_formula = None
|
||||
self.extra_1d_channel = None
|
||||
|
||||
if config is None:
|
||||
if "decoder.mid.block_1.mix_factor" in sd:
|
||||
@ -427,6 +440,20 @@ class VAE:
|
||||
ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
|
||||
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
|
||||
self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
|
||||
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
|
||||
self.latent_channels = 8
|
||||
self.output_channels = 2
|
||||
self.upscale_ratio = 4096
|
||||
self.downscale_ratio = 4096
|
||||
self.latent_dim = 2
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
self.disable_offload = True
|
||||
self.extra_1d_channel = 16
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
@ -485,7 +512,13 @@ class VAE:
|
||||
return output
|
||||
|
||||
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
if samples.ndim == 3:
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
else:
|
||||
og_shape = samples.shape
|
||||
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
|
||||
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
||||
|
||||
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
||||
@ -505,9 +538,24 @@ class VAE:
|
||||
samples /= 3.0
|
||||
return samples
|
||||
|
||||
def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
||||
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
|
||||
if self.latent_dim == 1:
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
out_channels = self.latent_channels
|
||||
upscale_amount = 1 / self.downscale_ratio
|
||||
else:
|
||||
extra_channel_size = self.extra_1d_channel
|
||||
out_channels = self.latent_channels * extra_channel_size
|
||||
tile_x = tile_x // extra_channel_size
|
||||
overlap = overlap // extra_channel_size
|
||||
upscale_amount = 1 / self.downscale_ratio
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
|
||||
|
||||
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
|
||||
if self.latent_dim == 1:
|
||||
return out
|
||||
else:
|
||||
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
|
||||
|
||||
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
@ -532,7 +580,7 @@ class VAE:
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
dims = samples_in.ndim - 2
|
||||
if dims == 1:
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
elif dims == 2:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
@ -599,7 +647,7 @@ class VAE:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
elif self.latent_dim == 1:
|
||||
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
|
||||
samples = self.encode_tiled_1d(pixel_samples)
|
||||
else:
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
@ -704,6 +752,8 @@ class CLIPType(Enum):
|
||||
LUMINA2 = 12
|
||||
WAN = 13
|
||||
HIDREAM = 14
|
||||
CHROMA = 15
|
||||
ACE = 16
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
@ -808,7 +858,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif clip_type == CLIPType.LTXV:
|
||||
clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
|
||||
elif clip_type == CLIPType.PIXART:
|
||||
elif clip_type == CLIPType.PIXART or clip_type == CLIPType.CHROMA:
|
||||
clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer
|
||||
elif clip_type == CLIPType.WAN:
|
||||
@ -829,8 +879,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
|
||||
elif te_model == TEModel.T5_BASE:
|
||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||
if clip_type == CLIPType.ACE or "spiece_model" in clip_data[0]:
|
||||
clip_target.clip = comfy.text_encoders.ace.AceT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.ace.AceT5Tokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
else:
|
||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||
elif te_model == TEModel.GEMMA_2_2B:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
||||
|
@ -457,13 +457,14 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
return embed_out
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}, tokenizer_args={}):
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
|
||||
if tokenizer_path is None:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
|
||||
self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
|
||||
self.min_length = min_length
|
||||
self.end_token = None
|
||||
self.min_padding = min_padding
|
||||
|
||||
empty = self.tokenizer('')["input_ids"]
|
||||
self.tokenizer_adds_end_token = has_end_token
|
||||
@ -518,13 +519,15 @@ class SDTokenizer:
|
||||
return (embed, leftover)
|
||||
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
|
||||
'''
|
||||
Takes a prompt and converts it to a list of (token, weight, word id) elements.
|
||||
Tokens can both be integer tokens and pre computed CLIP tensors.
|
||||
Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
|
||||
Returned list has the dimensions NxM where M is the input size of CLIP
|
||||
'''
|
||||
min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length)
|
||||
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
|
||||
|
||||
text = escape_important(text)
|
||||
parsed_weights = token_weights(text, 1.0)
|
||||
@ -603,10 +606,12 @@ class SDTokenizer:
|
||||
#fill last batch
|
||||
if self.end_token is not None:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
if min_padding is not None:
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
|
||||
if self.pad_to_max_length and len(batch) < self.max_length:
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||
if self.min_length is not None and len(batch) < self.min_length:
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (self.min_length - len(batch)))
|
||||
if min_length is not None and len(batch) < min_length:
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
|
||||
|
||||
if not return_word_ids:
|
||||
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
|
||||
@ -634,7 +639,7 @@ class SD1Tokenizer:
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
|
||||
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
|
@ -28,8 +28,8 @@ class SDXLTokenizer:
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
|
@ -17,6 +17,7 @@ import comfy.text_encoders.hunyuan_video
|
||||
import comfy.text_encoders.cosmos
|
||||
import comfy.text_encoders.lumina2
|
||||
import comfy.text_encoders.wan
|
||||
import comfy.text_encoders.ace
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@ -785,6 +786,10 @@ class LTXV(supported_models_base.BASE):
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = (unet_config.get("cross_attention_dim", 2048) / 2048) * 5.5
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.LTXV(self, device=device)
|
||||
return out
|
||||
@ -993,6 +998,10 @@ class WAN21_Vace(WAN21_T2V):
|
||||
"model_type": "vace",
|
||||
}
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = 1.2 * self.memory_usage_factor
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
|
||||
return out
|
||||
@ -1064,7 +1073,62 @@ class HiDream(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return None # TODO
|
||||
|
||||
class Chroma(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "chroma",
|
||||
}
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream]
|
||||
unet_extra_config = {
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
}
|
||||
|
||||
latent_format = comfy.latent_formats.Flux
|
||||
|
||||
memory_usage_factor = 3.2
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Chroma(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
|
||||
|
||||
class ACEStep(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"audio_model": "ace",
|
||||
}
|
||||
|
||||
unet_extra_config = {
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
latent_format = comfy.latent_formats.ACEAudio
|
||||
|
||||
memory_usage_factor = 0.5
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.ACEStep(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
153
comfy/text_encoders/ace.py
Normal file
153
comfy/text_encoders/ace.py
Normal file
@ -0,0 +1,153 @@
|
||||
from comfy import sd1_clip
|
||||
from .spiece_tokenizer import SPieceTokenizer
|
||||
import comfy.text_encoders.t5
|
||||
import os
|
||||
import re
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from tokenizers import Tokenizer
|
||||
from .ace_text_cleaners import multilingual_cleaners, japanese_to_romaji
|
||||
|
||||
SUPPORT_LANGUAGES = {
|
||||
"en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,
|
||||
"pt": 286, "pl": 294, "tr": 295, "ru": 267, "cs": 293,
|
||||
"nl": 297, "ar": 5022, "zh": 5023, "ja": 5412, "hu": 5753,
|
||||
"ko": 6152, "hi": 6680
|
||||
}
|
||||
|
||||
structure_pattern = re.compile(r"\[.*?\]")
|
||||
|
||||
DEFAULT_VOCAB_FILE = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
|
||||
|
||||
|
||||
class VoiceBpeTokenizer:
|
||||
def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
|
||||
self.tokenizer = None
|
||||
if vocab_file is not None:
|
||||
self.tokenizer = Tokenizer.from_file(vocab_file)
|
||||
|
||||
def preprocess_text(self, txt, lang):
|
||||
txt = multilingual_cleaners(txt, lang)
|
||||
return txt
|
||||
|
||||
def encode(self, txt, lang='en'):
|
||||
# lang = lang.split("-")[0] # remove the region
|
||||
# self.check_input_length(txt, lang)
|
||||
txt = self.preprocess_text(txt, lang)
|
||||
lang = "zh-cn" if lang == "zh" else lang
|
||||
txt = f"[{lang}]{txt}"
|
||||
txt = txt.replace(" ", "[SPACE]")
|
||||
return self.tokenizer.encode(txt).ids
|
||||
|
||||
def get_lang(self, line):
|
||||
if line.startswith("[") and line[3:4] == ']':
|
||||
lang = line[1:3].lower()
|
||||
if lang in SUPPORT_LANGUAGES:
|
||||
return lang, line[4:]
|
||||
return "en", line
|
||||
|
||||
def __call__(self, string):
|
||||
lines = string.split("\n")
|
||||
lyric_token_idx = [261]
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
lyric_token_idx += [2]
|
||||
continue
|
||||
|
||||
lang, line = self.get_lang(line)
|
||||
|
||||
if lang not in SUPPORT_LANGUAGES:
|
||||
lang = "en"
|
||||
if "zh" in lang:
|
||||
lang = "zh"
|
||||
if "spa" in lang:
|
||||
lang = "es"
|
||||
|
||||
try:
|
||||
line_out = japanese_to_romaji(line)
|
||||
if line_out != line:
|
||||
lang = "ja"
|
||||
line = line_out
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
if structure_pattern.match(line):
|
||||
token_idx = self.encode(line, "en")
|
||||
else:
|
||||
token_idx = self.encode(line, lang)
|
||||
lyric_token_idx = lyric_token_idx + token_idx + [2]
|
||||
except Exception as e:
|
||||
logging.warning("tokenize error {} for line {} major_language {}".format(e, line, lang))
|
||||
return {"input_ids": lyric_token_idx}
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(path, **kwargs):
|
||||
return VoiceBpeTokenizer(path, **kwargs)
|
||||
|
||||
def get_vocab(self):
|
||||
return {}
|
||||
|
||||
|
||||
class UMT5BaseModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "umt5_config_base.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=False, model_options=model_options)
|
||||
|
||||
class UMT5BaseTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=768, embedding_key='umt5base', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=0, tokenizer_data=tokenizer_data)
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
|
||||
class LyricsTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='lyrics', tokenizer_class=VoiceBpeTokenizer, has_start_token=True, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=2, has_end_token=False, tokenizer_data=tokenizer_data)
|
||||
|
||||
class AceT5Tokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.voicebpe = LyricsTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.umt5base = UMT5BaseTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
out["lyrics"] = self.voicebpe.tokenize_with_weights(kwargs.get("lyrics", ""), return_word_ids, **kwargs)
|
||||
out["umt5base"] = self.umt5base.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return self.umt5base.untokenize(token_weight_pair)
|
||||
|
||||
def state_dict(self):
|
||||
return self.umt5base.state_dict()
|
||||
|
||||
class AceT5Model(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__()
|
||||
self.umt5base = UMT5BaseModel(device=device, dtype=dtype, model_options=model_options)
|
||||
self.dtypes = set()
|
||||
if dtype is not None:
|
||||
self.dtypes.add(dtype)
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.umt5base.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
self.umt5base.reset_clip_options()
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs_umt5base = token_weight_pairs["umt5base"]
|
||||
token_weight_pairs_lyrics = token_weight_pairs["lyrics"]
|
||||
|
||||
t5_out, t5_pooled = self.umt5base.encode_token_weights(token_weight_pairs_umt5base)
|
||||
|
||||
lyrics_embeds = torch.tensor(list(map(lambda a: a[0], token_weight_pairs_lyrics[0]))).unsqueeze(0)
|
||||
return t5_out, None, {"conditioning_lyrics": lyrics_embeds}
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.umt5base.load_sd(sd)
|
15535
comfy/text_encoders/ace_lyrics_tokenizer/vocab.json
Normal file
15535
comfy/text_encoders/ace_lyrics_tokenizer/vocab.json
Normal file
File diff suppressed because it is too large
Load Diff
395
comfy/text_encoders/ace_text_cleaners.py
Normal file
395
comfy/text_encoders/ace_text_cleaners.py
Normal file
@ -0,0 +1,395 @@
|
||||
# basic text cleaners for the ACE step model
|
||||
# I didn't copy the ones from the reference code because I didn't want to deal with the dependencies
|
||||
# TODO: more languages than english?
|
||||
|
||||
import re
|
||||
|
||||
def japanese_to_romaji(japanese_text):
|
||||
"""
|
||||
Convert Japanese hiragana and katakana to romaji (Latin alphabet representation).
|
||||
|
||||
Args:
|
||||
japanese_text (str): Text containing hiragana and/or katakana characters
|
||||
|
||||
Returns:
|
||||
str: The romaji (Latin alphabet) equivalent
|
||||
"""
|
||||
# Dictionary mapping kana characters to their romaji equivalents
|
||||
kana_map = {
|
||||
# Katakana characters
|
||||
'ア': 'a', 'イ': 'i', 'ウ': 'u', 'エ': 'e', 'オ': 'o',
|
||||
'カ': 'ka', 'キ': 'ki', 'ク': 'ku', 'ケ': 'ke', 'コ': 'ko',
|
||||
'サ': 'sa', 'シ': 'shi', 'ス': 'su', 'セ': 'se', 'ソ': 'so',
|
||||
'タ': 'ta', 'チ': 'chi', 'ツ': 'tsu', 'テ': 'te', 'ト': 'to',
|
||||
'ナ': 'na', 'ニ': 'ni', 'ヌ': 'nu', 'ネ': 'ne', 'ノ': 'no',
|
||||
'ハ': 'ha', 'ヒ': 'hi', 'フ': 'fu', 'ヘ': 'he', 'ホ': 'ho',
|
||||
'マ': 'ma', 'ミ': 'mi', 'ム': 'mu', 'メ': 'me', 'モ': 'mo',
|
||||
'ヤ': 'ya', 'ユ': 'yu', 'ヨ': 'yo',
|
||||
'ラ': 'ra', 'リ': 'ri', 'ル': 'ru', 'レ': 're', 'ロ': 'ro',
|
||||
'ワ': 'wa', 'ヲ': 'wo', 'ン': 'n',
|
||||
|
||||
# Katakana voiced consonants
|
||||
'ガ': 'ga', 'ギ': 'gi', 'グ': 'gu', 'ゲ': 'ge', 'ゴ': 'go',
|
||||
'ザ': 'za', 'ジ': 'ji', 'ズ': 'zu', 'ゼ': 'ze', 'ゾ': 'zo',
|
||||
'ダ': 'da', 'ヂ': 'ji', 'ヅ': 'zu', 'デ': 'de', 'ド': 'do',
|
||||
'バ': 'ba', 'ビ': 'bi', 'ブ': 'bu', 'ベ': 'be', 'ボ': 'bo',
|
||||
'パ': 'pa', 'ピ': 'pi', 'プ': 'pu', 'ペ': 'pe', 'ポ': 'po',
|
||||
|
||||
# Katakana combinations
|
||||
'キャ': 'kya', 'キュ': 'kyu', 'キョ': 'kyo',
|
||||
'シャ': 'sha', 'シュ': 'shu', 'ショ': 'sho',
|
||||
'チャ': 'cha', 'チュ': 'chu', 'チョ': 'cho',
|
||||
'ニャ': 'nya', 'ニュ': 'nyu', 'ニョ': 'nyo',
|
||||
'ヒャ': 'hya', 'ヒュ': 'hyu', 'ヒョ': 'hyo',
|
||||
'ミャ': 'mya', 'ミュ': 'myu', 'ミョ': 'myo',
|
||||
'リャ': 'rya', 'リュ': 'ryu', 'リョ': 'ryo',
|
||||
'ギャ': 'gya', 'ギュ': 'gyu', 'ギョ': 'gyo',
|
||||
'ジャ': 'ja', 'ジュ': 'ju', 'ジョ': 'jo',
|
||||
'ビャ': 'bya', 'ビュ': 'byu', 'ビョ': 'byo',
|
||||
'ピャ': 'pya', 'ピュ': 'pyu', 'ピョ': 'pyo',
|
||||
|
||||
# Katakana small characters and special cases
|
||||
'ッ': '', # Small tsu (doubles the following consonant)
|
||||
'ャ': 'ya', 'ュ': 'yu', 'ョ': 'yo',
|
||||
|
||||
# Katakana extras
|
||||
'ヴ': 'vu', 'ファ': 'fa', 'フィ': 'fi', 'フェ': 'fe', 'フォ': 'fo',
|
||||
'ウィ': 'wi', 'ウェ': 'we', 'ウォ': 'wo',
|
||||
|
||||
# Hiragana characters
|
||||
'あ': 'a', 'い': 'i', 'う': 'u', 'え': 'e', 'お': 'o',
|
||||
'か': 'ka', 'き': 'ki', 'く': 'ku', 'け': 'ke', 'こ': 'ko',
|
||||
'さ': 'sa', 'し': 'shi', 'す': 'su', 'せ': 'se', 'そ': 'so',
|
||||
'た': 'ta', 'ち': 'chi', 'つ': 'tsu', 'て': 'te', 'と': 'to',
|
||||
'な': 'na', 'に': 'ni', 'ぬ': 'nu', 'ね': 'ne', 'の': 'no',
|
||||
'は': 'ha', 'ひ': 'hi', 'ふ': 'fu', 'へ': 'he', 'ほ': 'ho',
|
||||
'ま': 'ma', 'み': 'mi', 'む': 'mu', 'め': 'me', 'も': 'mo',
|
||||
'や': 'ya', 'ゆ': 'yu', 'よ': 'yo',
|
||||
'ら': 'ra', 'り': 'ri', 'る': 'ru', 'れ': 're', 'ろ': 'ro',
|
||||
'わ': 'wa', 'を': 'wo', 'ん': 'n',
|
||||
|
||||
# Hiragana voiced consonants
|
||||
'が': 'ga', 'ぎ': 'gi', 'ぐ': 'gu', 'げ': 'ge', 'ご': 'go',
|
||||
'ざ': 'za', 'じ': 'ji', 'ず': 'zu', 'ぜ': 'ze', 'ぞ': 'zo',
|
||||
'だ': 'da', 'ぢ': 'ji', 'づ': 'zu', 'で': 'de', 'ど': 'do',
|
||||
'ば': 'ba', 'び': 'bi', 'ぶ': 'bu', 'べ': 'be', 'ぼ': 'bo',
|
||||
'ぱ': 'pa', 'ぴ': 'pi', 'ぷ': 'pu', 'ぺ': 'pe', 'ぽ': 'po',
|
||||
|
||||
# Hiragana combinations
|
||||
'きゃ': 'kya', 'きゅ': 'kyu', 'きょ': 'kyo',
|
||||
'しゃ': 'sha', 'しゅ': 'shu', 'しょ': 'sho',
|
||||
'ちゃ': 'cha', 'ちゅ': 'chu', 'ちょ': 'cho',
|
||||
'にゃ': 'nya', 'にゅ': 'nyu', 'にょ': 'nyo',
|
||||
'ひゃ': 'hya', 'ひゅ': 'hyu', 'ひょ': 'hyo',
|
||||
'みゃ': 'mya', 'みゅ': 'myu', 'みょ': 'myo',
|
||||
'りゃ': 'rya', 'りゅ': 'ryu', 'りょ': 'ryo',
|
||||
'ぎゃ': 'gya', 'ぎゅ': 'gyu', 'ぎょ': 'gyo',
|
||||
'じゃ': 'ja', 'じゅ': 'ju', 'じょ': 'jo',
|
||||
'びゃ': 'bya', 'びゅ': 'byu', 'びょ': 'byo',
|
||||
'ぴゃ': 'pya', 'ぴゅ': 'pyu', 'ぴょ': 'pyo',
|
||||
|
||||
# Hiragana small characters and special cases
|
||||
'っ': '', # Small tsu (doubles the following consonant)
|
||||
'ゃ': 'ya', 'ゅ': 'yu', 'ょ': 'yo',
|
||||
|
||||
# Common punctuation and spaces
|
||||
' ': ' ', # Japanese space
|
||||
'、': ', ', '。': '. ',
|
||||
}
|
||||
|
||||
result = []
|
||||
i = 0
|
||||
|
||||
while i < len(japanese_text):
|
||||
# Check for small tsu (doubling the following consonant)
|
||||
if i < len(japanese_text) - 1 and (japanese_text[i] == 'っ' or japanese_text[i] == 'ッ'):
|
||||
if i < len(japanese_text) - 1 and japanese_text[i+1] in kana_map:
|
||||
next_romaji = kana_map[japanese_text[i+1]]
|
||||
if next_romaji and next_romaji[0] not in 'aiueon':
|
||||
result.append(next_romaji[0]) # Double the consonant
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Check for combinations with small ya, yu, yo
|
||||
if i < len(japanese_text) - 1 and japanese_text[i+1] in ('ゃ', 'ゅ', 'ょ', 'ャ', 'ュ', 'ョ'):
|
||||
combo = japanese_text[i:i+2]
|
||||
if combo in kana_map:
|
||||
result.append(kana_map[combo])
|
||||
i += 2
|
||||
continue
|
||||
|
||||
# Regular character
|
||||
if japanese_text[i] in kana_map:
|
||||
result.append(kana_map[japanese_text[i]])
|
||||
else:
|
||||
# If it's not in our map, keep it as is (might be kanji, romaji, etc.)
|
||||
result.append(japanese_text[i])
|
||||
|
||||
i += 1
|
||||
|
||||
return ''.join(result)
|
||||
|
||||
def number_to_text(num, ordinal=False):
|
||||
"""
|
||||
Convert a number (int or float) to its text representation.
|
||||
|
||||
Args:
|
||||
num: The number to convert
|
||||
|
||||
Returns:
|
||||
str: Text representation of the number
|
||||
"""
|
||||
|
||||
if not isinstance(num, (int, float)):
|
||||
return "Input must be a number"
|
||||
|
||||
# Handle special case of zero
|
||||
if num == 0:
|
||||
return "zero"
|
||||
|
||||
# Handle negative numbers
|
||||
negative = num < 0
|
||||
num = abs(num)
|
||||
|
||||
# Handle floats
|
||||
if isinstance(num, float):
|
||||
# Split into integer and decimal parts
|
||||
int_part = int(num)
|
||||
|
||||
# Convert both parts
|
||||
int_text = _int_to_text(int_part)
|
||||
|
||||
# Handle decimal part (convert to string and remove '0.')
|
||||
decimal_str = str(num).split('.')[1]
|
||||
decimal_text = " point " + " ".join(_digit_to_text(int(digit)) for digit in decimal_str)
|
||||
|
||||
result = int_text + decimal_text
|
||||
else:
|
||||
# Handle integers
|
||||
result = _int_to_text(num)
|
||||
|
||||
# Add 'negative' prefix for negative numbers
|
||||
if negative:
|
||||
result = "negative " + result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _int_to_text(num):
|
||||
"""Helper function to convert an integer to text"""
|
||||
|
||||
ones = ["", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine",
|
||||
"ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", "sixteen",
|
||||
"seventeen", "eighteen", "nineteen"]
|
||||
|
||||
tens = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"]
|
||||
|
||||
if num < 20:
|
||||
return ones[num]
|
||||
|
||||
if num < 100:
|
||||
return tens[num // 10] + (" " + ones[num % 10] if num % 10 != 0 else "")
|
||||
|
||||
if num < 1000:
|
||||
return ones[num // 100] + " hundred" + (" " + _int_to_text(num % 100) if num % 100 != 0 else "")
|
||||
|
||||
if num < 1000000:
|
||||
return _int_to_text(num // 1000) + " thousand" + (" " + _int_to_text(num % 1000) if num % 1000 != 0 else "")
|
||||
|
||||
if num < 1000000000:
|
||||
return _int_to_text(num // 1000000) + " million" + (" " + _int_to_text(num % 1000000) if num % 1000000 != 0 else "")
|
||||
|
||||
return _int_to_text(num // 1000000000) + " billion" + (" " + _int_to_text(num % 1000000000) if num % 1000000000 != 0 else "")
|
||||
|
||||
|
||||
def _digit_to_text(digit):
|
||||
"""Convert a single digit to text"""
|
||||
digits = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
|
||||
return digits[digit]
|
||||
|
||||
|
||||
_whitespace_re = re.compile(r"\s+")
|
||||
|
||||
|
||||
# List of (regular expression, replacement) pairs for abbreviations:
|
||||
_abbreviations = {
|
||||
"en": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("mrs", "misess"),
|
||||
("mr", "mister"),
|
||||
("dr", "doctor"),
|
||||
("st", "saint"),
|
||||
("co", "company"),
|
||||
("jr", "junior"),
|
||||
("maj", "major"),
|
||||
("gen", "general"),
|
||||
("drs", "doctors"),
|
||||
("rev", "reverend"),
|
||||
("lt", "lieutenant"),
|
||||
("hon", "honorable"),
|
||||
("sgt", "sergeant"),
|
||||
("capt", "captain"),
|
||||
("esq", "esquire"),
|
||||
("ltd", "limited"),
|
||||
("col", "colonel"),
|
||||
("ft", "fort"),
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def expand_abbreviations_multilingual(text, lang="en"):
|
||||
for regex, replacement in _abbreviations[lang]:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
|
||||
|
||||
_symbols_multilingual = {
|
||||
"en": [
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " and "),
|
||||
("@", " at "),
|
||||
("%", " percent "),
|
||||
("#", " hash "),
|
||||
("$", " dollar "),
|
||||
("£", " pound "),
|
||||
("°", " degree "),
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def expand_symbols_multilingual(text, lang="en"):
|
||||
for regex, replacement in _symbols_multilingual[lang]:
|
||||
text = re.sub(regex, replacement, text)
|
||||
text = text.replace(" ", " ") # Ensure there are no double spaces
|
||||
return text.strip()
|
||||
|
||||
|
||||
_ordinal_re = {
|
||||
"en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
|
||||
}
|
||||
_number_re = re.compile(r"[0-9]+")
|
||||
_currency_re = {
|
||||
"USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
|
||||
"GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
|
||||
"EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
|
||||
}
|
||||
|
||||
_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
|
||||
_dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
|
||||
_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
|
||||
|
||||
|
||||
def _remove_commas(m):
|
||||
text = m.group(0)
|
||||
if "," in text:
|
||||
text = text.replace(",", "")
|
||||
return text
|
||||
|
||||
|
||||
def _remove_dots(m):
|
||||
text = m.group(0)
|
||||
if "." in text:
|
||||
text = text.replace(".", "")
|
||||
return text
|
||||
|
||||
|
||||
def _expand_decimal_point(m, lang="en"):
|
||||
amount = m.group(1).replace(",", ".")
|
||||
return number_to_text(float(amount))
|
||||
|
||||
|
||||
def _expand_currency(m, lang="en", currency="USD"):
|
||||
amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
|
||||
full_amount = number_to_text(amount)
|
||||
|
||||
and_equivalents = {
|
||||
"en": ", ",
|
||||
"es": " con ",
|
||||
"fr": " et ",
|
||||
"de": " und ",
|
||||
"pt": " e ",
|
||||
"it": " e ",
|
||||
"pl": ", ",
|
||||
"cs": ", ",
|
||||
"ru": ", ",
|
||||
"nl": ", ",
|
||||
"ar": ", ",
|
||||
"tr": ", ",
|
||||
"hu": ", ",
|
||||
"ko": ", ",
|
||||
}
|
||||
|
||||
if amount.is_integer():
|
||||
last_and = full_amount.rfind(and_equivalents[lang])
|
||||
if last_and != -1:
|
||||
full_amount = full_amount[:last_and]
|
||||
|
||||
return full_amount
|
||||
|
||||
|
||||
def _expand_ordinal(m, lang="en"):
|
||||
return number_to_text(int(m.group(1)), ordinal=True)
|
||||
|
||||
|
||||
def _expand_number(m, lang="en"):
|
||||
return number_to_text(int(m.group(0)))
|
||||
|
||||
|
||||
def expand_numbers_multilingual(text, lang="en"):
|
||||
if lang in ["en", "ru"]:
|
||||
text = re.sub(_comma_number_re, _remove_commas, text)
|
||||
else:
|
||||
text = re.sub(_dot_number_re, _remove_dots, text)
|
||||
try:
|
||||
text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
|
||||
text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
|
||||
text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
|
||||
except:
|
||||
pass
|
||||
|
||||
text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
|
||||
text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
|
||||
text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
|
||||
return text
|
||||
|
||||
|
||||
def lowercase(text):
|
||||
return text.lower()
|
||||
|
||||
|
||||
def collapse_whitespace(text):
|
||||
return re.sub(_whitespace_re, " ", text)
|
||||
|
||||
|
||||
def multilingual_cleaners(text, lang):
|
||||
text = text.replace('"', "")
|
||||
if lang == "tr":
|
||||
text = text.replace("İ", "i")
|
||||
text = text.replace("Ö", "ö")
|
||||
text = text.replace("Ü", "ü")
|
||||
text = lowercase(text)
|
||||
try:
|
||||
text = expand_numbers_multilingual(text, lang)
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
text = expand_abbreviations_multilingual(text, lang)
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
text = expand_symbols_multilingual(text, lang=lang)
|
||||
except:
|
||||
pass
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def basic_cleaners(text):
|
||||
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
@ -19,8 +19,8 @@ class FluxTokenizer:
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
|
@ -16,11 +16,11 @@ class HiDreamTokenizer:
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||
t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids)
|
||||
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens
|
||||
out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
|
||||
out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
|
@ -49,13 +49,13 @@ class HunyuanVideoTokenizer:
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
|
||||
out = {}
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
|
||||
if llama_template is None:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids)
|
||||
llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids, **kwargs)
|
||||
embed_count = 0
|
||||
for r in llama_text_tokens:
|
||||
for i in range(len(r)):
|
||||
|
@ -41,8 +41,8 @@ class HyditTokenizer:
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids)
|
||||
out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids)
|
||||
out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
|
@ -45,9 +45,9 @@ class SD3Tokenizer:
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
|
||||
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
|
22
comfy/text_encoders/umt5_config_base.json
Normal file
22
comfy/text_encoders/umt5_config_base.json
Normal file
@ -0,0 +1,22 @@
|
||||
{
|
||||
"d_ff": 2048,
|
||||
"d_kv": 64,
|
||||
"d_model": 768,
|
||||
"decoder_start_token_id": 0,
|
||||
"dropout_rate": 0.1,
|
||||
"eos_token_id": 1,
|
||||
"dense_act_fn": "gelu_pytorch_tanh",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"is_gated_act": true,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "umt5",
|
||||
"num_decoder_layers": 12,
|
||||
"num_heads": 12,
|
||||
"num_layers": 12,
|
||||
"output_past": true,
|
||||
"pad_token_id": 0,
|
||||
"relative_attention_num_buckets": 32,
|
||||
"tie_word_embeddings": false,
|
||||
"vocab_size": 256384
|
||||
}
|
@ -28,6 +28,9 @@ import logging
|
||||
import itertools
|
||||
from torch.nn.functional import interpolate
|
||||
from einops import rearrange
|
||||
from comfy.cli_args import args
|
||||
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
|
||||
ALWAYS_SAFE_LOAD = False
|
||||
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
|
||||
@ -67,8 +70,12 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
|
||||
raise e
|
||||
else:
|
||||
torch_args = {}
|
||||
if MMAP_TORCH_FILES:
|
||||
torch_args["mmap"] = True
|
||||
|
||||
if safe_load or ALWAYS_SAFE_LOAD:
|
||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
|
||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
|
||||
else:
|
||||
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
|
||||
if "global_step" in pl_sd:
|
||||
|
@ -24,7 +24,7 @@ class BOFTAdapter(WeightAdapterBase):
|
||||
) -> Optional["BOFTAdapter"]:
|
||||
if loaded_keys is None:
|
||||
loaded_keys = set()
|
||||
blocks_name = "{}.boft_blocks".format(x)
|
||||
blocks_name = "{}.oft_blocks".format(x)
|
||||
rescale_name = "{}.rescale".format(x)
|
||||
|
||||
blocks = None
|
||||
@ -32,17 +32,18 @@ class BOFTAdapter(WeightAdapterBase):
|
||||
blocks = lora[blocks_name]
|
||||
if blocks.ndim == 4:
|
||||
loaded_keys.add(blocks_name)
|
||||
else:
|
||||
blocks = None
|
||||
if blocks is None:
|
||||
return None
|
||||
|
||||
rescale = None
|
||||
if rescale_name in lora.keys():
|
||||
rescale = lora[rescale_name]
|
||||
loaded_keys.add(rescale_name)
|
||||
|
||||
if blocks is not None:
|
||||
weights = (blocks, rescale, alpha, dora_scale)
|
||||
return cls(loaded_keys, weights)
|
||||
else:
|
||||
return None
|
||||
weights = (blocks, rescale, alpha, dora_scale)
|
||||
return cls(loaded_keys, weights)
|
||||
|
||||
def calculate_weight(
|
||||
self,
|
||||
@ -71,7 +72,7 @@ class BOFTAdapter(WeightAdapterBase):
|
||||
# Get r
|
||||
I = torch.eye(boft_b, device=blocks.device, dtype=blocks.dtype)
|
||||
# for Q = -Q^T
|
||||
q = blocks - blocks.transpose(1, 2)
|
||||
q = blocks - blocks.transpose(-1, -2)
|
||||
normed_q = q
|
||||
if alpha > 0: # alpha in boft/bboft is for constraint
|
||||
q_norm = torch.norm(q) + 1e-8
|
||||
@ -79,9 +80,8 @@ class BOFTAdapter(WeightAdapterBase):
|
||||
normed_q = q * alpha / q_norm
|
||||
# use float() to prevent unsupported type in .inverse()
|
||||
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||
r = r.to(original_weight)
|
||||
|
||||
inp = org = original_weight
|
||||
r = r.to(weight)
|
||||
inp = org = weight
|
||||
|
||||
r_b = boft_b//2
|
||||
for i in range(boft_m):
|
||||
@ -91,14 +91,14 @@ class BOFTAdapter(WeightAdapterBase):
|
||||
if strength != 1:
|
||||
bi = bi * strength + (1-strength) * I
|
||||
inp = (
|
||||
inp.unflatten(-1, (-1, g, k))
|
||||
.transpose(-2, -1)
|
||||
.flatten(-3)
|
||||
.unflatten(-1, (-1, boft_b))
|
||||
inp.unflatten(0, (-1, g, k))
|
||||
.transpose(1, 2)
|
||||
.flatten(0, 2)
|
||||
.unflatten(0, (-1, boft_b))
|
||||
)
|
||||
inp = torch.einsum("b n m, b n ... -> b m ...", inp, bi)
|
||||
inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp)
|
||||
inp = (
|
||||
inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3)
|
||||
inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2)
|
||||
)
|
||||
|
||||
if rescale is not None:
|
||||
@ -109,7 +109,7 @@ class BOFTAdapter(WeightAdapterBase):
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
weight += function((strength * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||
return weight
|
||||
|
@ -32,17 +32,18 @@ class OFTAdapter(WeightAdapterBase):
|
||||
blocks = lora[blocks_name]
|
||||
if blocks.ndim == 3:
|
||||
loaded_keys.add(blocks_name)
|
||||
else:
|
||||
blocks = None
|
||||
if blocks is None:
|
||||
return None
|
||||
|
||||
rescale = None
|
||||
if rescale_name in lora.keys():
|
||||
rescale = lora[rescale_name]
|
||||
loaded_keys.add(rescale_name)
|
||||
|
||||
if blocks is not None:
|
||||
weights = (blocks, rescale, alpha, dora_scale)
|
||||
return cls(loaded_keys, weights)
|
||||
else:
|
||||
return None
|
||||
weights = (blocks, rescale, alpha, dora_scale)
|
||||
return cls(loaded_keys, weights)
|
||||
|
||||
def calculate_weight(
|
||||
self,
|
||||
@ -79,16 +80,17 @@ class OFTAdapter(WeightAdapterBase):
|
||||
normed_q = q * alpha / q_norm
|
||||
# use float() to prevent unsupported type in .inverse()
|
||||
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||
r = r.to(original_weight)
|
||||
r = r.to(weight)
|
||||
_, *shape = weight.shape
|
||||
lora_diff = torch.einsum(
|
||||
"k n m, k n ... -> k m ...",
|
||||
(r * strength) - strength * I,
|
||||
original_weight,
|
||||
)
|
||||
weight.view(block_num, block_size, *shape),
|
||||
).view(-1, *shape)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
weight += function((strength * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||
return weight
|
||||
|
8
comfy_api/input/__init__.py
Normal file
8
comfy_api/input/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
from .basic_types import ImageInput, AudioInput
|
||||
from .video_types import VideoInput
|
||||
|
||||
__all__ = [
|
||||
"ImageInput",
|
||||
"AudioInput",
|
||||
"VideoInput",
|
||||
]
|
20
comfy_api/input/basic_types.py
Normal file
20
comfy_api/input/basic_types.py
Normal file
@ -0,0 +1,20 @@
|
||||
import torch
|
||||
from typing import TypedDict
|
||||
|
||||
ImageInput = torch.Tensor
|
||||
"""
|
||||
An image in format [B, H, W, C] where B is the batch size, C is the number of channels,
|
||||
"""
|
||||
|
||||
class AudioInput(TypedDict):
|
||||
"""
|
||||
TypedDict representing audio input.
|
||||
"""
|
||||
|
||||
waveform: torch.Tensor
|
||||
"""
|
||||
Tensor in the format [B, C, T] where B is the batch size, C is the number of channels,
|
||||
"""
|
||||
|
||||
sample_rate: int
|
||||
|
45
comfy_api/input/video_types.py
Normal file
45
comfy_api/input/video_types.py
Normal file
@ -0,0 +1,45 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
class VideoInput(ABC):
|
||||
"""
|
||||
Abstract base class for video input types.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_components(self) -> VideoComponents:
|
||||
"""
|
||||
Abstract method to get the video components (images, audio, and frame rate).
|
||||
|
||||
Returns:
|
||||
VideoComponents containing images, audio, and frame rate
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
"""
|
||||
Abstract method to save the video input to a file.
|
||||
"""
|
||||
pass
|
||||
|
||||
# Provide a default implementation, but subclasses can provide optimized versions
|
||||
# if possible.
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
"""
|
||||
Returns the dimensions of the video input.
|
||||
|
||||
Returns:
|
||||
Tuple of (width, height)
|
||||
"""
|
||||
components = self.get_components()
|
||||
return components.images.shape[2], components.images.shape[1]
|
||||
|
7
comfy_api/input_impl/__init__.py
Normal file
7
comfy_api/input_impl/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
from .video_types import VideoFromFile, VideoFromComponents
|
||||
|
||||
__all__ = [
|
||||
# Implementations
|
||||
"VideoFromFile",
|
||||
"VideoFromComponents",
|
||||
]
|
271
comfy_api/input_impl/video_types.py
Normal file
271
comfy_api/input_impl/video_types.py
Normal file
@ -0,0 +1,271 @@
|
||||
from __future__ import annotations
|
||||
from av.container import InputContainer
|
||||
from av.subtitles.stream import SubtitleStream
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.input import AudioInput
|
||||
import av
|
||||
import io
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
|
||||
def container_to_output_format(container_format: str | None) -> str | None:
|
||||
"""
|
||||
A container's `format` may be a comma-separated list of formats.
|
||||
E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
|
||||
However, writing to a file/stream with `av.open` requires a single format,
|
||||
or `None` to auto-detect.
|
||||
"""
|
||||
if not container_format:
|
||||
return None # Auto-detect
|
||||
|
||||
if "," not in container_format:
|
||||
return container_format
|
||||
|
||||
formats = container_format.split(",")
|
||||
return formats[0]
|
||||
|
||||
|
||||
def get_open_write_kwargs(
|
||||
dest: str | io.BytesIO, container_format: str, to_format: str | None
|
||||
) -> dict:
|
||||
"""Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
|
||||
open_kwargs = {
|
||||
"mode": "w",
|
||||
# If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
|
||||
"options": {"movflags": "use_metadata_tags"},
|
||||
}
|
||||
|
||||
is_write_to_buffer = isinstance(dest, io.BytesIO)
|
||||
if is_write_to_buffer:
|
||||
# Set output format explicitly, since it cannot be inferred from file extension
|
||||
if to_format == VideoContainer.AUTO:
|
||||
to_format = container_format.lower()
|
||||
elif isinstance(to_format, str):
|
||||
to_format = to_format.lower()
|
||||
open_kwargs["format"] = container_to_output_format(to_format)
|
||||
|
||||
return open_kwargs
|
||||
|
||||
|
||||
class VideoFromFile(VideoInput):
|
||||
"""
|
||||
Class representing video input from a file.
|
||||
"""
|
||||
|
||||
def __init__(self, file: str | io.BytesIO):
|
||||
"""
|
||||
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
|
||||
containing the file contents.
|
||||
"""
|
||||
self.__file = file
|
||||
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
"""
|
||||
Returns the dimensions of the video input.
|
||||
|
||||
Returns:
|
||||
Tuple of (width, height)
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
for stream in container.streams:
|
||||
if stream.type == 'video':
|
||||
assert isinstance(stream, av.VideoStream)
|
||||
return stream.width, stream.height
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def get_components_internal(self, container: InputContainer) -> VideoComponents:
|
||||
# Get video frames
|
||||
frames = []
|
||||
for frame in container.decode(video=0):
|
||||
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
|
||||
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
|
||||
frames.append(img)
|
||||
|
||||
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
|
||||
|
||||
# Get frame rate
|
||||
video_stream = next(s for s in container.streams if s.type == 'video')
|
||||
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
|
||||
|
||||
# Get audio if available
|
||||
audio = None
|
||||
try:
|
||||
container.seek(0) # Reset the container to the beginning
|
||||
for stream in container.streams:
|
||||
if stream.type != 'audio':
|
||||
continue
|
||||
assert isinstance(stream, av.AudioStream)
|
||||
audio_frames = []
|
||||
for packet in container.demux(stream):
|
||||
for frame in packet.decode():
|
||||
assert isinstance(frame, av.AudioFrame)
|
||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||
if len(audio_frames) > 0:
|
||||
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
||||
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
||||
audio = AudioInput({
|
||||
"waveform": audio_tensor,
|
||||
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
|
||||
})
|
||||
except StopIteration:
|
||||
pass # No audio stream
|
||||
|
||||
metadata = container.metadata
|
||||
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
return self.get_components_internal(container)
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str | io.BytesIO,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
container_format = container.format.name
|
||||
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
|
||||
reuse_streams = True
|
||||
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
||||
reuse_streams = False
|
||||
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||
reuse_streams = False
|
||||
|
||||
if not reuse_streams:
|
||||
components = self.get_components_internal(container)
|
||||
video = VideoFromComponents(components)
|
||||
return video.save_to(
|
||||
path,
|
||||
format=format,
|
||||
codec=codec,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
streams = container.streams
|
||||
|
||||
open_kwargs = get_open_write_kwargs(path, container_format, format)
|
||||
with av.open(path, **open_kwargs) as output_container:
|
||||
# Copy over the original metadata
|
||||
for key, value in container.metadata.items():
|
||||
if metadata is None or key not in metadata:
|
||||
output_container.metadata[key] = value
|
||||
|
||||
# Add our new metadata
|
||||
if metadata is not None:
|
||||
for key, value in metadata.items():
|
||||
if isinstance(value, str):
|
||||
output_container.metadata[key] = value
|
||||
else:
|
||||
output_container.metadata[key] = json.dumps(value)
|
||||
|
||||
# Add streams to the new container
|
||||
stream_map = {}
|
||||
for stream in streams:
|
||||
if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
|
||||
out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
|
||||
stream_map[stream] = out_stream
|
||||
|
||||
# Write packets to the new container
|
||||
for packet in container.demux():
|
||||
if packet.stream in stream_map and packet.dts is not None:
|
||||
packet.stream = stream_map[packet.stream]
|
||||
output_container.mux(packet)
|
||||
|
||||
class VideoFromComponents(VideoInput):
|
||||
"""
|
||||
Class representing video input from tensors.
|
||||
"""
|
||||
|
||||
def __init__(self, components: VideoComponents):
|
||||
self.__components = components
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
return VideoComponents(
|
||||
images=self.__components.images,
|
||||
audio=self.__components.audio,
|
||||
frame_rate=self.__components.frame_rate
|
||||
)
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||
raise ValueError("Only H264 codec is supported for now")
|
||||
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
|
||||
# Add metadata before writing any streams
|
||||
if metadata is not None:
|
||||
for key, value in metadata.items():
|
||||
output.metadata[key] = json.dumps(value)
|
||||
|
||||
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
||||
# Create a video stream
|
||||
video_stream = output.add_stream('h264', rate=frame_rate)
|
||||
video_stream.width = self.__components.images.shape[2]
|
||||
video_stream.height = self.__components.images.shape[1]
|
||||
video_stream.pix_fmt = 'yuv420p'
|
||||
|
||||
# Create an audio stream
|
||||
audio_sample_rate = 1
|
||||
audio_stream: Optional[av.AudioStream] = None
|
||||
if self.__components.audio:
|
||||
audio_sample_rate = int(self.__components.audio['sample_rate'])
|
||||
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
|
||||
audio_stream.sample_rate = audio_sample_rate
|
||||
audio_stream.format = 'fltp'
|
||||
|
||||
# Encode video
|
||||
for i, frame in enumerate(self.__components.images):
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
||||
packet = video_stream.encode(frame)
|
||||
output.mux(packet)
|
||||
|
||||
# Flush video
|
||||
packet = video_stream.encode(None)
|
||||
output.mux(packet)
|
||||
|
||||
if audio_stream and self.__components.audio:
|
||||
# Encode audio
|
||||
samples_per_frame = int(audio_sample_rate / frame_rate)
|
||||
num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
|
||||
for i in range(num_frames):
|
||||
start = i * samples_per_frame
|
||||
end = start + samples_per_frame
|
||||
# TODO(Feature) - Add support for stereo audio
|
||||
chunk = (
|
||||
self.__components.audio["waveform"][0, 0, start:end]
|
||||
.unsqueeze(0)
|
||||
.contiguous()
|
||||
.numpy()
|
||||
)
|
||||
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
|
||||
audio_frame.sample_rate = audio_sample_rate
|
||||
audio_frame.pts = i * samples_per_frame
|
||||
for packet in audio_stream.encode(audio_frame):
|
||||
output.mux(packet)
|
||||
|
||||
# Flush audio
|
||||
for packet in audio_stream.encode(None):
|
||||
output.mux(packet)
|
||||
|
8
comfy_api/util/__init__.py
Normal file
8
comfy_api/util/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
__all__ = [
|
||||
# Utility Types
|
||||
"VideoContainer",
|
||||
"VideoCodec",
|
||||
"VideoComponents",
|
||||
]
|
51
comfy_api/util/video_types.py
Normal file
51
comfy_api/util/video_types.py
Normal file
@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.input import ImageInput, AudioInput
|
||||
|
||||
class VideoCodec(str, Enum):
|
||||
AUTO = "auto"
|
||||
H264 = "h264"
|
||||
|
||||
@classmethod
|
||||
def as_input(cls) -> list[str]:
|
||||
"""
|
||||
Returns a list of codec names that can be used as node input.
|
||||
"""
|
||||
return [member.value for member in cls]
|
||||
|
||||
class VideoContainer(str, Enum):
|
||||
AUTO = "auto"
|
||||
MP4 = "mp4"
|
||||
|
||||
@classmethod
|
||||
def as_input(cls) -> list[str]:
|
||||
"""
|
||||
Returns a list of container names that can be used as node input.
|
||||
"""
|
||||
return [member.value for member in cls]
|
||||
|
||||
@classmethod
|
||||
def get_extension(cls, value) -> str:
|
||||
"""
|
||||
Returns the file extension for the container.
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
value = cls(value)
|
||||
if value == VideoContainer.MP4 or value == VideoContainer.AUTO:
|
||||
return "mp4"
|
||||
return ""
|
||||
|
||||
@dataclass
|
||||
class VideoComponents:
|
||||
"""
|
||||
Dataclass representing the components of a video.
|
||||
"""
|
||||
|
||||
images: ImageInput
|
||||
frame_rate: Fraction
|
||||
audio: Optional[AudioInput] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
41
comfy_api_nodes/README.md
Normal file
41
comfy_api_nodes/README.md
Normal file
@ -0,0 +1,41 @@
|
||||
# ComfyUI API Nodes
|
||||
|
||||
## Introduction
|
||||
|
||||
Below are a collection of nodes that work by calling external APIs. More information available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview#api-nodes).
|
||||
|
||||
## Development
|
||||
|
||||
While developing, you should be testing against the Staging environment. To test against staging:
|
||||
|
||||
**Install ComfyUI_frontend**
|
||||
|
||||
Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to start the frontend server. By default, it will connect to Staging authentication.
|
||||
|
||||
> **Hint:** If you use --front-end-version argument for ComfyUI, it will use production authentication.
|
||||
|
||||
```bash
|
||||
python run main.py --comfy-api-base https://stagingapi.comfy.org
|
||||
```
|
||||
|
||||
API stubs are generated through automatic codegen tools from OpenAPI definitions. Since the Comfy Org OpenAPI definition contains many things from the Comfy Registry as well, we use redocly/cli to filter out only the paths relevant for API nodes.
|
||||
|
||||
### Redocly Instructions
|
||||
|
||||
**Tip**
|
||||
When developing locally, use the `redocly-dev.yaml` file to generate pydantic models. This lets you use stubs for APIs that are not marked `Released` yet.
|
||||
|
||||
Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging.
|
||||
|
||||
```bash
|
||||
# Download the OpenAPI file from prod server.
|
||||
curl -o openapi.yaml https://stagingapi.comfy.org/openapi
|
||||
|
||||
# Filter out unneeded API definitions.
|
||||
npm install -g @redocly/cli
|
||||
redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly-dev.yaml --remove-unused-components
|
||||
|
||||
# Generate the pydantic datamodels for validation.
|
||||
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
|
||||
|
||||
```
|
0
comfy_api_nodes/__init__.py
Normal file
0
comfy_api_nodes/__init__.py
Normal file
576
comfy_api_nodes/apinode_utils.py
Normal file
576
comfy_api_nodes/apinode_utils.py
Normal file
@ -0,0 +1,576 @@
|
||||
from __future__ import annotations
|
||||
import io
|
||||
import logging
|
||||
from typing import Optional
|
||||
from comfy.utils import common_upscale
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.util import VideoContainer, VideoCodec
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy_api.input.basic_types import AudioInput
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiClient,
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
UploadRequest,
|
||||
UploadResponse,
|
||||
)
|
||||
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import requests
|
||||
import torch
|
||||
import math
|
||||
import base64
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
import av
|
||||
|
||||
|
||||
def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
|
||||
"""Downloads a video from a URL and returns a `VIDEO` output.
|
||||
|
||||
Args:
|
||||
video_url: The URL of the video to download.
|
||||
|
||||
Returns:
|
||||
A Comfy node `VIDEO` output.
|
||||
"""
|
||||
video_io = download_url_to_bytesio(video_url, timeout)
|
||||
if video_io is None:
|
||||
error_msg = f"Failed to download video from {video_url}"
|
||||
logging.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
return VideoFromFile(video_io)
|
||||
|
||||
|
||||
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
|
||||
"""Downscale input image tensor to roughly the specified total pixels."""
|
||||
samples = image.movedim(-1, 1)
|
||||
total = int(total_pixels)
|
||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||
if scale_by >= 1:
|
||||
return image
|
||||
width = round(samples.shape[3] * scale_by)
|
||||
height = round(samples.shape[2] * scale_by)
|
||||
|
||||
s = common_upscale(samples, width, height, "lanczos", "disabled")
|
||||
s = s.movedim(1, -1)
|
||||
return s
|
||||
|
||||
|
||||
def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor:
|
||||
"""Validates and casts a response to a torch.Tensor.
|
||||
|
||||
Args:
|
||||
response: The response to validate and cast.
|
||||
timeout: Request timeout in seconds. Defaults to None (no timeout).
|
||||
|
||||
Returns:
|
||||
A torch.Tensor representing the image (1, H, W, C).
|
||||
|
||||
Raises:
|
||||
ValueError: If the response is not valid.
|
||||
"""
|
||||
# validate raw JSON response
|
||||
data = response.data
|
||||
if not data or len(data) == 0:
|
||||
raise ValueError("No images returned from API endpoint")
|
||||
|
||||
# Initialize list to store image tensors
|
||||
image_tensors: list[torch.Tensor] = []
|
||||
|
||||
# Process each image in the data array
|
||||
for image_data in data:
|
||||
image_url = image_data.url
|
||||
b64_data = image_data.b64_json
|
||||
|
||||
if not image_url and not b64_data:
|
||||
raise ValueError("No image was generated in the response")
|
||||
|
||||
if b64_data:
|
||||
img_data = base64.b64decode(b64_data)
|
||||
img = Image.open(io.BytesIO(img_data))
|
||||
|
||||
elif image_url:
|
||||
img_response = requests.get(image_url, timeout=timeout)
|
||||
if img_response.status_code != 200:
|
||||
raise ValueError("Failed to download the image")
|
||||
img = Image.open(io.BytesIO(img_response.content))
|
||||
|
||||
img = img.convert("RGBA")
|
||||
|
||||
# Convert to numpy array, normalize to float32 between 0 and 1
|
||||
img_array = np.array(img).astype(np.float32) / 255.0
|
||||
img_tensor = torch.from_numpy(img_array)
|
||||
|
||||
# Add to list of tensors
|
||||
image_tensors.append(img_tensor)
|
||||
|
||||
return torch.stack(image_tensors, dim=0)
|
||||
|
||||
|
||||
def validate_aspect_ratio(
|
||||
aspect_ratio: str,
|
||||
minimum_ratio: float,
|
||||
maximum_ratio: float,
|
||||
minimum_ratio_str: str,
|
||||
maximum_ratio_str: str,
|
||||
) -> float:
|
||||
"""Validates and casts an aspect ratio string to a float.
|
||||
|
||||
Args:
|
||||
aspect_ratio: The aspect ratio string to validate.
|
||||
minimum_ratio: The minimum aspect ratio.
|
||||
maximum_ratio: The maximum aspect ratio.
|
||||
minimum_ratio_str: The minimum aspect ratio string.
|
||||
maximum_ratio_str: The maximum aspect ratio string.
|
||||
|
||||
Returns:
|
||||
The validated and cast aspect ratio.
|
||||
|
||||
Raises:
|
||||
Exception: If the aspect ratio is not valid.
|
||||
"""
|
||||
# get ratio values
|
||||
numbers = aspect_ratio.split(":")
|
||||
if len(numbers) != 2:
|
||||
raise TypeError(
|
||||
f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
|
||||
)
|
||||
try:
|
||||
numerator = int(numbers[0])
|
||||
denominator = int(numbers[1])
|
||||
except ValueError as exc:
|
||||
raise TypeError(
|
||||
f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
|
||||
) from exc
|
||||
calculated_ratio = numerator / denominator
|
||||
# if not close to minimum and maximum, check bounds
|
||||
if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose(
|
||||
calculated_ratio, maximum_ratio
|
||||
):
|
||||
if calculated_ratio < minimum_ratio:
|
||||
raise TypeError(
|
||||
f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
|
||||
)
|
||||
elif calculated_ratio > maximum_ratio:
|
||||
raise TypeError(
|
||||
f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
|
||||
)
|
||||
return aspect_ratio
|
||||
|
||||
|
||||
def mimetype_to_extension(mime_type: str) -> str:
|
||||
"""Converts a MIME type to a file extension."""
|
||||
return mime_type.split("/")[-1].lower()
|
||||
|
||||
|
||||
def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
|
||||
"""Downloads content from a URL using requests and returns it as BytesIO.
|
||||
|
||||
Args:
|
||||
url: The URL to download.
|
||||
timeout: Request timeout in seconds. Defaults to None (no timeout).
|
||||
|
||||
Returns:
|
||||
BytesIO object containing the downloaded content.
|
||||
"""
|
||||
response = requests.get(url, stream=True, timeout=timeout)
|
||||
response.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
|
||||
return BytesIO(response.content)
|
||||
|
||||
|
||||
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
|
||||
"""Converts image data from BytesIO to a torch.Tensor.
|
||||
|
||||
Args:
|
||||
image_bytesio: BytesIO object containing the image data.
|
||||
mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").
|
||||
|
||||
Returns:
|
||||
A torch.Tensor representing the image (1, H, W, C).
|
||||
|
||||
Raises:
|
||||
PIL.UnidentifiedImageError: If the image data cannot be identified.
|
||||
ValueError: If the specified mode is invalid.
|
||||
"""
|
||||
image = Image.open(image_bytesio)
|
||||
image = image.convert(mode)
|
||||
image_array = np.array(image).astype(np.float32) / 255.0
|
||||
return torch.from_numpy(image_array).unsqueeze(0)
|
||||
|
||||
|
||||
def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
|
||||
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
|
||||
image_bytesio = download_url_to_bytesio(url, timeout)
|
||||
return bytesio_to_image_tensor(image_bytesio)
|
||||
|
||||
def process_image_response(response: requests.Response) -> torch.Tensor:
|
||||
"""Uses content from a Response object and converts it to a torch.Tensor"""
|
||||
return bytesio_to_image_tensor(BytesIO(response.content))
|
||||
|
||||
|
||||
def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
|
||||
"""Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
|
||||
if len(image.shape) > 3:
|
||||
image = image[0]
|
||||
# TODO: remove alpha if not allowed and present
|
||||
input_tensor = image.cpu()
|
||||
input_tensor = downscale_image_tensor(
|
||||
input_tensor.unsqueeze(0), total_pixels=total_pixels
|
||||
).squeeze()
|
||||
image_np = (input_tensor.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
return img
|
||||
|
||||
|
||||
def _pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
|
||||
"""Converts a PIL Image to a BytesIO object."""
|
||||
if not mime_type:
|
||||
mime_type = "image/png"
|
||||
|
||||
img_byte_arr = io.BytesIO()
|
||||
# Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
|
||||
pil_format = mime_type.split("/")[-1].upper()
|
||||
if pil_format == "JPG":
|
||||
pil_format = "JPEG"
|
||||
img.save(img_byte_arr, format=pil_format)
|
||||
img_byte_arr.seek(0)
|
||||
return img_byte_arr
|
||||
|
||||
|
||||
def tensor_to_bytesio(
|
||||
image: torch.Tensor,
|
||||
name: Optional[str] = None,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> BytesIO:
|
||||
"""Converts a torch.Tensor image to a named BytesIO object.
|
||||
|
||||
Args:
|
||||
image: Input torch.Tensor image.
|
||||
name: Optional filename for the BytesIO object.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||
|
||||
Returns:
|
||||
Named BytesIO object containing the image data.
|
||||
"""
|
||||
if not mime_type:
|
||||
mime_type = "image/png"
|
||||
|
||||
pil_image = _tensor_to_pil(image, total_pixels=total_pixels)
|
||||
img_binary = _pil_to_bytesio(pil_image, mime_type=mime_type)
|
||||
img_binary.name = (
|
||||
f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
|
||||
)
|
||||
return img_binary
|
||||
|
||||
|
||||
def tensor_to_base64_string(
|
||||
image_tensor: torch.Tensor,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> str:
|
||||
"""Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.
|
||||
|
||||
Args:
|
||||
image_tensor: Input torch.Tensor image.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||
|
||||
Returns:
|
||||
Base64 encoded string of the image.
|
||||
"""
|
||||
pil_image = _tensor_to_pil(image_tensor, total_pixels=total_pixels)
|
||||
img_byte_arr = _pil_to_bytesio(pil_image, mime_type=mime_type)
|
||||
img_bytes = img_byte_arr.getvalue()
|
||||
# Encode bytes to base64 string
|
||||
base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
|
||||
return base64_encoded_string
|
||||
|
||||
|
||||
def tensor_to_data_uri(
|
||||
image_tensor: torch.Tensor,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> str:
|
||||
"""Converts a tensor image to a Data URI string.
|
||||
|
||||
Args:
|
||||
image_tensor: Input torch.Tensor image.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
|
||||
|
||||
Returns:
|
||||
Data URI string (e.g., 'data:image/png;base64,...').
|
||||
"""
|
||||
base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
|
||||
return f"data:{mime_type};base64,{base64_string}"
|
||||
|
||||
|
||||
def upload_file_to_comfyapi(
|
||||
file_bytes_io: BytesIO,
|
||||
filename: str,
|
||||
upload_mime_type: str,
|
||||
auth_kwargs: Optional[dict[str,str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Uploads a single file to ComfyUI API and returns its download URL.
|
||||
|
||||
Args:
|
||||
file_bytes_io: BytesIO object containing the file data.
|
||||
filename: The filename of the file.
|
||||
upload_mime_type: MIME type of the file.
|
||||
auth_kwargs: Optional authentication token(s).
|
||||
|
||||
Returns:
|
||||
The download URL for the uploaded file.
|
||||
"""
|
||||
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/customers/storage",
|
||||
method=HttpMethod.POST,
|
||||
request_model=UploadRequest,
|
||||
response_model=UploadResponse,
|
||||
),
|
||||
request=request_object,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
response: UploadResponse = operation.execute()
|
||||
upload_response = ApiClient.upload_file(
|
||||
response.upload_url, file_bytes_io, content_type=upload_mime_type
|
||||
)
|
||||
upload_response.raise_for_status()
|
||||
|
||||
return response.download_url
|
||||
|
||||
|
||||
def upload_video_to_comfyapi(
|
||||
video: VideoInput,
|
||||
auth_kwargs: Optional[dict[str,str]] = None,
|
||||
container: VideoContainer = VideoContainer.MP4,
|
||||
codec: VideoCodec = VideoCodec.H264,
|
||||
max_duration: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Uploads a single video to ComfyUI API and returns its download URL.
|
||||
Uses the specified container and codec for saving the video before upload.
|
||||
|
||||
Args:
|
||||
video: VideoInput object (Comfy VIDEO type).
|
||||
auth_kwargs: Optional authentication token(s).
|
||||
container: The video container format to use (default: MP4).
|
||||
codec: The video codec to use (default: H264).
|
||||
max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised.
|
||||
|
||||
Returns:
|
||||
The download URL for the uploaded video file.
|
||||
"""
|
||||
if max_duration is not None:
|
||||
try:
|
||||
actual_duration = video.duration_seconds
|
||||
if actual_duration is not None and actual_duration > max_duration:
|
||||
raise ValueError(
|
||||
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting video duration: {e}")
|
||||
raise ValueError(f"Could not verify video duration from source: {e}") from e
|
||||
|
||||
upload_mime_type = f"video/{container.value.lower()}"
|
||||
filename = f"uploaded_video.{container.value.lower()}"
|
||||
|
||||
# Convert VideoInput to BytesIO using specified container/codec
|
||||
video_bytes_io = io.BytesIO()
|
||||
video.save_to(video_bytes_io, format=container, codec=codec)
|
||||
video_bytes_io.seek(0)
|
||||
|
||||
return upload_file_to_comfyapi(
|
||||
video_bytes_io, filename, upload_mime_type, auth_kwargs
|
||||
)
|
||||
|
||||
|
||||
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
|
||||
"""
|
||||
Prepares audio waveform for av library by converting to a contiguous numpy array.
|
||||
|
||||
Args:
|
||||
waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type.
|
||||
|
||||
Returns:
|
||||
Contiguous numpy array of the audio waveform. If the audio was batched,
|
||||
the first item is taken.
|
||||
"""
|
||||
if waveform.ndim != 3 or waveform.shape[0] != 1:
|
||||
raise ValueError("Expected waveform tensor shape (1, channels, samples)")
|
||||
|
||||
# If batch is > 1, take first item
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = waveform[0]
|
||||
|
||||
# Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array
|
||||
audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
|
||||
if audio_data_np.dtype != np.float32:
|
||||
audio_data_np = audio_data_np.astype(np.float32)
|
||||
|
||||
return audio_data_np
|
||||
|
||||
|
||||
def audio_ndarray_to_bytesio(
|
||||
audio_data_np: np.ndarray,
|
||||
sample_rate: int,
|
||||
container_format: str = "mp4",
|
||||
codec_name: str = "aac",
|
||||
) -> BytesIO:
|
||||
"""
|
||||
Encodes a numpy array of audio data into a BytesIO object.
|
||||
"""
|
||||
audio_bytes_io = io.BytesIO()
|
||||
with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
|
||||
audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
|
||||
frame = av.AudioFrame.from_ndarray(
|
||||
audio_data_np,
|
||||
format="fltp",
|
||||
layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
|
||||
)
|
||||
frame.sample_rate = sample_rate
|
||||
frame.pts = 0
|
||||
|
||||
for packet in audio_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
|
||||
# Flush stream
|
||||
for packet in audio_stream.encode(None):
|
||||
output_container.mux(packet)
|
||||
|
||||
audio_bytes_io.seek(0)
|
||||
return audio_bytes_io
|
||||
|
||||
|
||||
def upload_audio_to_comfyapi(
|
||||
audio: AudioInput,
|
||||
auth_kwargs: Optional[dict[str,str]] = None,
|
||||
container_format: str = "mp4",
|
||||
codec_name: str = "aac",
|
||||
mime_type: str = "audio/mp4",
|
||||
filename: str = "uploaded_audio.mp4",
|
||||
) -> str:
|
||||
"""
|
||||
Uploads a single audio input to ComfyUI API and returns its download URL.
|
||||
Encodes the raw waveform into the specified format before uploading.
|
||||
|
||||
Args:
|
||||
audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate)
|
||||
auth_kwargs: Optional authentication token(s).
|
||||
|
||||
Returns:
|
||||
The download URL for the uploaded audio file.
|
||||
"""
|
||||
sample_rate: int = audio["sample_rate"]
|
||||
waveform: torch.Tensor = audio["waveform"]
|
||||
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
||||
audio_bytes_io = audio_ndarray_to_bytesio(
|
||||
audio_data_np, sample_rate, container_format, codec_name
|
||||
)
|
||||
|
||||
return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
|
||||
|
||||
|
||||
def upload_images_to_comfyapi(
|
||||
image: torch.Tensor, max_images=8, auth_kwargs: Optional[dict[str,str]] = None, mime_type: Optional[str] = None
|
||||
) -> list[str]:
|
||||
"""
|
||||
Uploads images to ComfyUI API and returns download URLs.
|
||||
To upload multiple images, stack them in the batch dimension first.
|
||||
|
||||
Args:
|
||||
image: Input torch.Tensor image.
|
||||
max_images: Maximum number of images to upload.
|
||||
auth_kwargs: Optional authentication token(s).
|
||||
mime_type: Optional MIME type for the image.
|
||||
"""
|
||||
# if batch, try to upload each file if max_images is greater than 0
|
||||
idx_image = 0
|
||||
download_urls: list[str] = []
|
||||
is_batch = len(image.shape) > 3
|
||||
batch_length = 1
|
||||
if is_batch:
|
||||
batch_length = image.shape[0]
|
||||
while True:
|
||||
curr_image = image
|
||||
if len(image.shape) > 3:
|
||||
curr_image = image[idx_image]
|
||||
# get BytesIO version of image
|
||||
img_binary = tensor_to_bytesio(curr_image, mime_type=mime_type)
|
||||
# first, request upload/download urls from comfy API
|
||||
if not mime_type:
|
||||
request_object = UploadRequest(file_name=img_binary.name)
|
||||
else:
|
||||
request_object = UploadRequest(
|
||||
file_name=img_binary.name, content_type=mime_type
|
||||
)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/customers/storage",
|
||||
method=HttpMethod.POST,
|
||||
request_model=UploadRequest,
|
||||
response_model=UploadResponse,
|
||||
),
|
||||
request=request_object,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response = operation.execute()
|
||||
|
||||
upload_response = ApiClient.upload_file(
|
||||
response.upload_url, img_binary, content_type=mime_type
|
||||
)
|
||||
# verify success
|
||||
try:
|
||||
upload_response.raise_for_status()
|
||||
except requests.exceptions.HTTPError as e:
|
||||
raise ValueError(f"Could not upload one or more images: {e}") from e
|
||||
# add download_url to list
|
||||
download_urls.append(response.download_url)
|
||||
|
||||
idx_image += 1
|
||||
# stop uploading additional files if done
|
||||
if is_batch and max_images > 0:
|
||||
if idx_image >= max_images:
|
||||
break
|
||||
if idx_image >= batch_length:
|
||||
break
|
||||
return download_urls
|
||||
|
||||
|
||||
def resize_mask_to_image(mask: torch.Tensor, image: torch.Tensor,
|
||||
upscale_method="nearest-exact", crop="disabled",
|
||||
allow_gradient=True, add_channel_dim=False):
|
||||
"""
|
||||
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
|
||||
"""
|
||||
_, H, W, _ = image.shape
|
||||
mask = mask.unsqueeze(-1)
|
||||
mask = mask.movedim(-1,1)
|
||||
mask = common_upscale(mask, width=W, height=H, upscale_method=upscale_method, crop=crop)
|
||||
mask = mask.movedim(1,-1)
|
||||
if not add_channel_dim:
|
||||
mask = mask.squeeze(-1)
|
||||
if not allow_gradient:
|
||||
mask = (mask > 0.5).float()
|
||||
return mask
|
||||
|
||||
|
||||
def validate_string(string: str, strip_whitespace=True, field_name="prompt", min_length=None, max_length=None):
|
||||
if strip_whitespace:
|
||||
string = string.strip()
|
||||
if min_length and len(string) < min_length:
|
||||
raise Exception(f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long.")
|
||||
if max_length and len(string) > max_length:
|
||||
raise Exception(f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long.")
|
||||
if not string:
|
||||
raise Exception(f"Field '{field_name}' cannot be empty.")
|
17
comfy_api_nodes/apis/PixverseController.py
Normal file
17
comfy_api_nodes/apis/PixverseController.py
Normal file
@ -0,0 +1,17 @@
|
||||
# generated by datamodel-codegen:
|
||||
# filename: filtered-openapi.yaml
|
||||
# timestamp: 2025-04-29T23:44:54+00:00
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from . import PixverseDto
|
||||
|
||||
|
||||
class ResponseData(BaseModel):
|
||||
ErrCode: Optional[int] = None
|
||||
ErrMsg: Optional[str] = None
|
||||
Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None
|
57
comfy_api_nodes/apis/PixverseDto.py
Normal file
57
comfy_api_nodes/apis/PixverseDto.py
Normal file
@ -0,0 +1,57 @@
|
||||
# generated by datamodel-codegen:
|
||||
# filename: filtered-openapi.yaml
|
||||
# timestamp: 2025-04-29T23:44:54+00:00
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class V2OpenAPII2VResp(BaseModel):
|
||||
video_id: Optional[int] = Field(None, description='Video_id')
|
||||
|
||||
|
||||
class V2OpenAPIT2VReq(BaseModel):
|
||||
aspect_ratio: str = Field(
|
||||
..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9']
|
||||
)
|
||||
duration: int = Field(
|
||||
...,
|
||||
description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)',
|
||||
examples=[5],
|
||||
)
|
||||
model: str = Field(
|
||||
..., description='Model version (only supports v3.5)', examples=['v3.5']
|
||||
)
|
||||
motion_mode: Optional[str] = Field(
|
||||
'normal',
|
||||
description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)',
|
||||
examples=['normal'],
|
||||
)
|
||||
negative_prompt: Optional[str] = Field(
|
||||
None, description='Negative prompt\n', max_length=2048
|
||||
)
|
||||
prompt: str = Field(..., description='Prompt', max_length=2048)
|
||||
quality: str = Field(
|
||||
...,
|
||||
description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")',
|
||||
examples=['540p'],
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647')
|
||||
style: Optional[str] = Field(
|
||||
None,
|
||||
description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed',
|
||||
examples=['anime'],
|
||||
)
|
||||
template_id: Optional[int] = Field(
|
||||
None,
|
||||
description='Template ID (template_id must be activated before use)',
|
||||
examples=[302325299692608],
|
||||
)
|
||||
water_mark: Optional[bool] = Field(
|
||||
False,
|
||||
description='Watermark (true: add watermark, false: no watermark)',
|
||||
examples=[False],
|
||||
)
|
3829
comfy_api_nodes/apis/__init__.py
Normal file
3829
comfy_api_nodes/apis/__init__.py
Normal file
File diff suppressed because it is too large
Load Diff
156
comfy_api_nodes/apis/bfl_api.py
Normal file
156
comfy_api_nodes/apis/bfl_api.py
Normal file
@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, confloat, conint
|
||||
|
||||
|
||||
class BFLOutputFormat(str, Enum):
|
||||
png = 'png'
|
||||
jpeg = 'jpeg'
|
||||
|
||||
|
||||
class BFLFluxExpandImageRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
top: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the top of the image')
|
||||
bottom: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the bottom of the image')
|
||||
left: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the left side of the image')
|
||||
right: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the right side of the image')
|
||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
image: str = Field(None, description='A Base64-encoded string representing the image you wish to expand')
|
||||
|
||||
|
||||
class BFLFluxFillImageRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
image: str = Field(None, description='A Base64-encoded string representing the image you wish to modify. Can contain alpha mask if desired.')
|
||||
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.')
|
||||
|
||||
|
||||
class BFLFluxCannyImageRequest(BaseModel):
|
||||
prompt: str = Field(..., description='Text prompt for image generation')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
canny_low_threshold: Optional[int] = Field(None, description='Low threshold for Canny edge detection')
|
||||
canny_high_threshold: Optional[int] = Field(None, description='High threshold for Canny edge detection')
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
|
||||
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
|
||||
|
||||
|
||||
class BFLFluxDepthImageRequest(BaseModel):
|
||||
prompt: str = Field(..., description='Text prompt for image generation')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
|
||||
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
|
||||
|
||||
|
||||
class BFLFluxProGenerateRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt for image generation.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
width: conint(ge=256, le=1440) = Field(1024, description='Width of the generated image in pixels. Must be a multiple of 32.')
|
||||
height: conint(ge=256, le=1440) = Field(768, description='Height of the generated image in pixels. Must be a multiple of 32.')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
|
||||
# image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
|
||||
# None, description='Blend between the prompt and the image prompt.'
|
||||
# )
|
||||
|
||||
|
||||
class BFLFluxProUltraGenerateRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt for image generation.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
raw: Optional[bool] = Field(None, description='Generate less processed, more natural-looking images.')
|
||||
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
|
||||
image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
|
||||
None, description='Blend between the prompt and the image prompt.'
|
||||
)
|
||||
|
||||
|
||||
class BFLFluxProGenerateResponse(BaseModel):
|
||||
id: str = Field(..., description='The unique identifier for the generation task.')
|
||||
polling_url: str = Field(..., description='URL to poll for the generation result.')
|
||||
|
||||
|
||||
class BFLStatus(str, Enum):
|
||||
task_not_found = "Task not found"
|
||||
pending = "Pending"
|
||||
request_moderated = "Request Moderated"
|
||||
content_moderated = "Content Moderated"
|
||||
ready = "Ready"
|
||||
error = "Error"
|
||||
|
||||
|
||||
class BFLFluxProStatusResponse(BaseModel):
|
||||
id: str = Field(..., description="The unique identifier for the generation task.")
|
||||
status: BFLStatus = Field(..., description="The status of the task.")
|
||||
result: Optional[Dict[str, Any]] = Field(
|
||||
None, description="The result of the task (null if not completed)."
|
||||
)
|
||||
progress: confloat(ge=0.0, le=1.0) = Field(
|
||||
..., description="The progress of the task (0.0 to 1.0)."
|
||||
)
|
||||
details: Optional[Dict[str, Any]] = Field(
|
||||
None, description="Additional details about the task (null if not available)."
|
||||
)
|
635
comfy_api_nodes/apis/client.py
Normal file
635
comfy_api_nodes/apis/client.py
Normal file
@ -0,0 +1,635 @@
|
||||
"""
|
||||
API Client Framework for api.comfy.org.
|
||||
|
||||
This module provides a flexible framework for making API requests from ComfyUI nodes.
|
||||
It supports both synchronous and asynchronous API operations with proper type validation.
|
||||
|
||||
Key Components:
|
||||
--------------
|
||||
1. ApiClient - Handles HTTP requests with authentication and error handling
|
||||
2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models
|
||||
3. ApiOperation - Executes a single synchronous API operation
|
||||
|
||||
Usage Examples:
|
||||
--------------
|
||||
|
||||
# Example 1: Synchronous API Operation
|
||||
# ------------------------------------
|
||||
# For a simple API call that returns the result immediately:
|
||||
|
||||
# 1. Create the API client
|
||||
api_client = ApiClient(
|
||||
base_url="https://api.example.com",
|
||||
auth_token="your_auth_token_here",
|
||||
comfy_api_key="your_comfy_api_key_here",
|
||||
timeout=30.0,
|
||||
verify_ssl=True
|
||||
)
|
||||
|
||||
# 2. Define the endpoint
|
||||
user_info_endpoint = ApiEndpoint(
|
||||
path="/v1/users/me",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest, # No request body needed
|
||||
response_model=UserProfile, # Pydantic model for the response
|
||||
query_params=None
|
||||
)
|
||||
|
||||
# 3. Create the request object
|
||||
request = EmptyRequest()
|
||||
|
||||
# 4. Create and execute the operation
|
||||
operation = ApiOperation(
|
||||
endpoint=user_info_endpoint,
|
||||
request=request
|
||||
)
|
||||
user_profile = operation.execute(client=api_client) # Returns immediately with the result
|
||||
|
||||
|
||||
# Example 2: Asynchronous API Operation with Polling
|
||||
# -------------------------------------------------
|
||||
# For an API that starts a task and requires polling for completion:
|
||||
|
||||
# 1. Define the endpoints (initial request and polling)
|
||||
generate_image_endpoint = ApiEndpoint(
|
||||
path="/v1/images/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=ImageGenerationRequest,
|
||||
response_model=TaskCreatedResponse,
|
||||
query_params=None
|
||||
)
|
||||
|
||||
check_task_endpoint = ApiEndpoint(
|
||||
path="/v1/tasks/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=ImageGenerationResult,
|
||||
query_params=None
|
||||
)
|
||||
|
||||
# 2. Create the request object
|
||||
request = ImageGenerationRequest(
|
||||
prompt="a beautiful sunset over mountains",
|
||||
width=1024,
|
||||
height=1024,
|
||||
num_images=1
|
||||
)
|
||||
|
||||
# 3. Create and execute the polling operation
|
||||
operation = PollingOperation(
|
||||
initial_endpoint=generate_image_endpoint,
|
||||
initial_request=request,
|
||||
poll_endpoint=check_task_endpoint,
|
||||
task_id_field="task_id",
|
||||
status_field="status",
|
||||
completed_statuses=["completed"],
|
||||
failed_statuses=["failed", "error"]
|
||||
)
|
||||
|
||||
# This will make the initial request and then poll until completion
|
||||
result = operation.execute(client=api_client) # Returns the final ImageGenerationResult when done
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import time
|
||||
import io
|
||||
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable
|
||||
from enum import Enum
|
||||
import json
|
||||
import requests
|
||||
from urllib.parse import urljoin
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from comfy.cli_args import args
|
||||
from comfy import utils
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
R = TypeVar("R", bound=BaseModel)
|
||||
P = TypeVar("P", bound=BaseModel) # For poll response
|
||||
|
||||
PROGRESS_BAR_MAX = 100
|
||||
|
||||
|
||||
class EmptyRequest(BaseModel):
|
||||
"""Base class for empty request bodies.
|
||||
For GET requests, fields will be sent as query parameters."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UploadRequest(BaseModel):
|
||||
file_name: str = Field(..., description="Filename to upload")
|
||||
content_type: str | None = Field(
|
||||
None,
|
||||
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
|
||||
)
|
||||
|
||||
|
||||
class UploadResponse(BaseModel):
|
||||
download_url: str = Field(..., description="URL to GET uploaded file")
|
||||
upload_url: str = Field(..., description="URL to PUT file to upload")
|
||||
|
||||
|
||||
class HttpMethod(str, Enum):
|
||||
GET = "GET"
|
||||
POST = "POST"
|
||||
PUT = "PUT"
|
||||
DELETE = "DELETE"
|
||||
PATCH = "PATCH"
|
||||
|
||||
|
||||
class ApiClient:
|
||||
"""
|
||||
Client for making HTTP requests to an API with authentication and error handling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
auth_token: Optional[str] = None,
|
||||
comfy_api_key: Optional[str] = None,
|
||||
timeout: float = 3600.0,
|
||||
verify_ssl: bool = True,
|
||||
):
|
||||
self.base_url = base_url
|
||||
self.auth_token = auth_token
|
||||
self.comfy_api_key = comfy_api_key
|
||||
self.timeout = timeout
|
||||
self.verify_ssl = verify_ssl
|
||||
|
||||
def _create_json_payload_args(
|
||||
self,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"json": data,
|
||||
"headers": headers,
|
||||
}
|
||||
|
||||
def _create_form_data_args(
|
||||
self,
|
||||
data: Dict[str, Any],
|
||||
files: Dict[str, Any],
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
multipart_parser = None,
|
||||
) -> Dict[str, Any]:
|
||||
if headers and "Content-Type" in headers:
|
||||
del headers["Content-Type"]
|
||||
|
||||
if multipart_parser:
|
||||
data = multipart_parser(data)
|
||||
|
||||
return {
|
||||
"data": data,
|
||||
"files": files,
|
||||
"headers": headers,
|
||||
}
|
||||
|
||||
def _create_urlencoded_form_data_args(
|
||||
self,
|
||||
data: Dict[str, Any],
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
headers = headers or {}
|
||||
headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
|
||||
return {
|
||||
"data": data,
|
||||
"headers": headers,
|
||||
}
|
||||
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
"""Get headers for API requests, including authentication if available"""
|
||||
headers = {"Content-Type": "application/json", "Accept": "application/json"}
|
||||
|
||||
if self.auth_token:
|
||||
headers["Authorization"] = f"Bearer {self.auth_token}"
|
||||
elif self.comfy_api_key:
|
||||
headers["X-API-KEY"] = self.comfy_api_key
|
||||
|
||||
return headers
|
||||
|
||||
def request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
files: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
content_type: str = "application/json",
|
||||
multipart_parser: Callable = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Make an HTTP request to the API
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
path: API endpoint path (will be joined with base_url)
|
||||
params: Query parameters
|
||||
data: body data
|
||||
files: Files to upload
|
||||
headers: Additional headers
|
||||
content_type: Content type of the request. Defaults to application/json.
|
||||
|
||||
Returns:
|
||||
Parsed JSON response
|
||||
|
||||
Raises:
|
||||
requests.RequestException: If the request fails
|
||||
"""
|
||||
url = urljoin(self.base_url, path)
|
||||
self.check_auth(self.auth_token, self.comfy_api_key)
|
||||
# Combine default headers with any provided headers
|
||||
request_headers = self.get_headers()
|
||||
if headers:
|
||||
request_headers.update(headers)
|
||||
|
||||
# Let requests handle the content type when files are present.
|
||||
if files:
|
||||
del request_headers["Content-Type"]
|
||||
|
||||
logging.debug(f"[DEBUG] Request Headers: {request_headers}")
|
||||
logging.debug(f"[DEBUG] Files: {files}")
|
||||
logging.debug(f"[DEBUG] Params: {params}")
|
||||
logging.debug(f"[DEBUG] Data: {data}")
|
||||
|
||||
if content_type == "application/x-www-form-urlencoded":
|
||||
payload_args = self._create_urlencoded_form_data_args(data, request_headers)
|
||||
elif content_type == "multipart/form-data":
|
||||
payload_args = self._create_form_data_args(
|
||||
data, files, request_headers, multipart_parser
|
||||
)
|
||||
else:
|
||||
payload_args = self._create_json_payload_args(data, request_headers)
|
||||
|
||||
try:
|
||||
response = requests.request(
|
||||
method=method,
|
||||
url=url,
|
||||
params=params,
|
||||
timeout=self.timeout,
|
||||
verify=self.verify_ssl,
|
||||
**payload_args,
|
||||
)
|
||||
|
||||
# Raise exception for error status codes
|
||||
response.raise_for_status()
|
||||
except requests.ConnectionError:
|
||||
raise Exception(
|
||||
f"Unable to connect to the API server at {self.base_url}. Please check your internet connection or verify the service is available."
|
||||
)
|
||||
|
||||
except requests.Timeout:
|
||||
raise Exception(
|
||||
f"Request timed out after {self.timeout} seconds. The server might be experiencing high load or the operation is taking longer than expected."
|
||||
)
|
||||
|
||||
except requests.HTTPError as e:
|
||||
status_code = e.response.status_code if hasattr(e, "response") else None
|
||||
error_message = f"HTTP Error: {str(e)}"
|
||||
|
||||
# Try to extract detailed error message from JSON response
|
||||
try:
|
||||
if hasattr(e, "response") and e.response.content:
|
||||
error_json = e.response.json()
|
||||
if "error" in error_json and "message" in error_json["error"]:
|
||||
error_message = f"API Error: {error_json['error']['message']}"
|
||||
if "type" in error_json["error"]:
|
||||
error_message += f" (Type: {error_json['error']['type']})"
|
||||
else:
|
||||
error_message = f"API Error: {error_json}"
|
||||
except Exception as json_error:
|
||||
# If we can't parse the JSON, fall back to the original error message
|
||||
logging.debug(
|
||||
f"[DEBUG] Failed to parse error response: {str(json_error)}"
|
||||
)
|
||||
|
||||
logging.debug(f"[DEBUG] API Error: {error_message} (Status: {status_code})")
|
||||
if hasattr(e, "response") and e.response.content:
|
||||
logging.debug(f"[DEBUG] Response content: {e.response.content}")
|
||||
if status_code == 401:
|
||||
error_message = "Unauthorized: Please login first to use this node."
|
||||
if status_code == 402:
|
||||
error_message = "Payment Required: Please add credits to your account to use this node."
|
||||
if status_code == 409:
|
||||
error_message = "There is a problem with your account. Please contact support@comfy.org. "
|
||||
if status_code == 429:
|
||||
error_message = "Rate Limit Exceeded: Please try again later."
|
||||
raise Exception(error_message)
|
||||
|
||||
# Parse and return JSON response
|
||||
if response.content:
|
||||
return response.json()
|
||||
return {}
|
||||
|
||||
def check_auth(self, auth_token, comfy_api_key):
|
||||
"""Verify that an auth token is present or comfy_api_key is present"""
|
||||
if auth_token is None and comfy_api_key is None:
|
||||
raise Exception("Unauthorized: Please login first to use this node.")
|
||||
return auth_token or comfy_api_key
|
||||
|
||||
@staticmethod
|
||||
def upload_file(
|
||||
upload_url: str,
|
||||
file: io.BytesIO | str,
|
||||
content_type: str | None = None,
|
||||
):
|
||||
"""Upload a file to the API. Make sure the file has a filename equal to what the url expects.
|
||||
|
||||
Args:
|
||||
upload_url: The URL to upload to
|
||||
file: Either a file path string, BytesIO object, or tuple of (file_path, filename)
|
||||
mime_type: Optional mime type to set for the upload
|
||||
"""
|
||||
headers = {}
|
||||
if content_type:
|
||||
headers["Content-Type"] = content_type
|
||||
|
||||
if isinstance(file, io.BytesIO):
|
||||
file.seek(0) # Ensure we're at the start of the file
|
||||
data = file.read()
|
||||
return requests.put(upload_url, data=data, headers=headers)
|
||||
elif isinstance(file, str):
|
||||
with open(file, "rb") as f:
|
||||
data = f.read()
|
||||
return requests.put(upload_url, data=data, headers=headers)
|
||||
|
||||
|
||||
class ApiEndpoint(Generic[T, R]):
|
||||
"""Defines an API endpoint with its request and response types"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
method: HttpMethod,
|
||||
request_model: Type[T],
|
||||
response_model: Type[R],
|
||||
query_params: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""Initialize an API endpoint definition.
|
||||
|
||||
Args:
|
||||
path: The URL path for this endpoint, can include placeholders like {id}
|
||||
method: The HTTP method to use (GET, POST, etc.)
|
||||
request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint
|
||||
response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint
|
||||
query_params: Optional dictionary of query parameters to include in the request
|
||||
"""
|
||||
self.path = path
|
||||
self.method = method
|
||||
self.request_model = request_model
|
||||
self.response_model = response_model
|
||||
self.query_params = query_params or {}
|
||||
|
||||
|
||||
class SynchronousOperation(Generic[T, R]):
|
||||
"""
|
||||
Represents a single synchronous API operation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: ApiEndpoint[T, R],
|
||||
request: T,
|
||||
files: Optional[Dict[str, Any]] = None,
|
||||
api_base: str | None = None,
|
||||
auth_token: Optional[str] = None,
|
||||
comfy_api_key: Optional[str] = None,
|
||||
auth_kwargs: Optional[Dict[str,str]] = None,
|
||||
timeout: float = 604800.0,
|
||||
verify_ssl: bool = True,
|
||||
content_type: str = "application/json",
|
||||
multipart_parser: Callable = None,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.request = request
|
||||
self.response = None
|
||||
self.error = None
|
||||
self.api_base: str = api_base or args.comfy_api_base
|
||||
self.auth_token = auth_token
|
||||
self.comfy_api_key = comfy_api_key
|
||||
if auth_kwargs is not None:
|
||||
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
|
||||
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
|
||||
self.timeout = timeout
|
||||
self.verify_ssl = verify_ssl
|
||||
self.files = files
|
||||
self.content_type = content_type
|
||||
self.multipart_parser = multipart_parser
|
||||
def execute(self, client: Optional[ApiClient] = None) -> R:
|
||||
"""Execute the API operation using the provided client or create one"""
|
||||
try:
|
||||
# Create client if not provided
|
||||
if client is None:
|
||||
client = ApiClient(
|
||||
base_url=self.api_base,
|
||||
auth_token=self.auth_token,
|
||||
comfy_api_key=self.comfy_api_key,
|
||||
timeout=self.timeout,
|
||||
verify_ssl=self.verify_ssl,
|
||||
)
|
||||
|
||||
# Convert request model to dict, but use None for EmptyRequest
|
||||
request_dict = (
|
||||
None
|
||||
if isinstance(self.request, EmptyRequest)
|
||||
else self.request.model_dump(exclude_none=True)
|
||||
)
|
||||
if request_dict:
|
||||
for key, value in request_dict.items():
|
||||
if isinstance(value, Enum):
|
||||
request_dict[key] = value.value
|
||||
|
||||
if request_dict:
|
||||
for key, value in request_dict.items():
|
||||
if isinstance(value, Enum):
|
||||
request_dict[key] = value.value
|
||||
|
||||
# Debug log for request
|
||||
logging.debug(
|
||||
f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
|
||||
)
|
||||
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
|
||||
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
|
||||
|
||||
# Make the request
|
||||
resp = client.request(
|
||||
method=self.endpoint.method.value,
|
||||
path=self.endpoint.path,
|
||||
data=request_dict,
|
||||
params=self.endpoint.query_params,
|
||||
files=self.files,
|
||||
content_type=self.content_type,
|
||||
multipart_parser=self.multipart_parser
|
||||
)
|
||||
|
||||
# Debug log for response
|
||||
logging.debug("=" * 50)
|
||||
logging.debug("[DEBUG] RESPONSE DETAILS:")
|
||||
logging.debug("[DEBUG] Status Code: 200 (Success)")
|
||||
logging.debug(f"[DEBUG] Response Body: {json.dumps(resp, indent=2)}")
|
||||
logging.debug("=" * 50)
|
||||
|
||||
# Parse and return the response
|
||||
return self._parse_response(resp)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"[DEBUG] API Exception: {str(e)}")
|
||||
raise Exception(str(e))
|
||||
|
||||
def _parse_response(self, resp):
|
||||
"""Parse response data - can be overridden by subclasses"""
|
||||
# The response is already the complete object, don't extract just the "data" field
|
||||
# as that would lose the outer structure (created timestamp, etc.)
|
||||
|
||||
# Parse response using the provided model
|
||||
self.response = self.endpoint.response_model.model_validate(resp)
|
||||
logging.debug(f"[DEBUG] Parsed Response: {self.response}")
|
||||
return self.response
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
"""Enum for task status values"""
|
||||
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
PENDING = "pending"
|
||||
|
||||
|
||||
class PollingOperation(Generic[T, R]):
|
||||
"""
|
||||
Represents an asynchronous API operation that requires polling for completion.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
poll_endpoint: ApiEndpoint[EmptyRequest, R],
|
||||
completed_statuses: list,
|
||||
failed_statuses: list,
|
||||
status_extractor: Callable[[R], str],
|
||||
progress_extractor: Callable[[R], float] = None,
|
||||
request: Optional[T] = None,
|
||||
api_base: str | None = None,
|
||||
auth_token: Optional[str] = None,
|
||||
comfy_api_key: Optional[str] = None,
|
||||
auth_kwargs: Optional[Dict[str,str]] = None,
|
||||
poll_interval: float = 5.0,
|
||||
):
|
||||
self.poll_endpoint = poll_endpoint
|
||||
self.request = request
|
||||
self.api_base: str = api_base or args.comfy_api_base
|
||||
self.auth_token = auth_token
|
||||
self.comfy_api_key = comfy_api_key
|
||||
if auth_kwargs is not None:
|
||||
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
|
||||
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
|
||||
self.poll_interval = poll_interval
|
||||
|
||||
# Polling configuration
|
||||
self.status_extractor = status_extractor or (
|
||||
lambda x: getattr(x, "status", None)
|
||||
)
|
||||
self.progress_extractor = progress_extractor
|
||||
self.completed_statuses = completed_statuses
|
||||
self.failed_statuses = failed_statuses
|
||||
|
||||
# For storing response data
|
||||
self.final_response = None
|
||||
self.error = None
|
||||
|
||||
def execute(self, client: Optional[ApiClient] = None) -> R:
|
||||
"""Execute the polling operation using the provided client. If failed, raise an exception."""
|
||||
try:
|
||||
if client is None:
|
||||
client = ApiClient(
|
||||
base_url=self.api_base,
|
||||
auth_token=self.auth_token,
|
||||
comfy_api_key=self.comfy_api_key,
|
||||
)
|
||||
return self._poll_until_complete(client)
|
||||
except Exception as e:
|
||||
raise Exception(f"Error during polling: {str(e)}")
|
||||
|
||||
def _check_task_status(self, response: R) -> TaskStatus:
|
||||
"""Check task status using the status extractor function"""
|
||||
try:
|
||||
status = self.status_extractor(response)
|
||||
if status in self.completed_statuses:
|
||||
return TaskStatus.COMPLETED
|
||||
elif status in self.failed_statuses:
|
||||
return TaskStatus.FAILED
|
||||
return TaskStatus.PENDING
|
||||
except Exception as e:
|
||||
logging.error(f"Error extracting status: {e}")
|
||||
return TaskStatus.PENDING
|
||||
|
||||
def _poll_until_complete(self, client: ApiClient) -> R:
|
||||
"""Poll until the task is complete"""
|
||||
poll_count = 0
|
||||
if self.progress_extractor:
|
||||
progress = utils.ProgressBar(PROGRESS_BAR_MAX)
|
||||
|
||||
while True:
|
||||
try:
|
||||
poll_count += 1
|
||||
logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
|
||||
|
||||
request_dict = (
|
||||
self.request.model_dump(exclude_none=True)
|
||||
if self.request is not None
|
||||
else None
|
||||
)
|
||||
|
||||
if poll_count == 1:
|
||||
logging.debug(
|
||||
f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}"
|
||||
)
|
||||
logging.debug(
|
||||
f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}"
|
||||
)
|
||||
|
||||
# Query task status
|
||||
resp = client.request(
|
||||
method=self.poll_endpoint.method.value,
|
||||
path=self.poll_endpoint.path,
|
||||
params=self.poll_endpoint.query_params,
|
||||
data=request_dict,
|
||||
)
|
||||
|
||||
# Parse response
|
||||
response_obj = self.poll_endpoint.response_model.model_validate(resp)
|
||||
# Check if task is complete
|
||||
status = self._check_task_status(response_obj)
|
||||
logging.debug(f"[DEBUG] Task Status: {status}")
|
||||
|
||||
# If progress extractor is provided, extract progress
|
||||
if self.progress_extractor:
|
||||
new_progress = self.progress_extractor(response_obj)
|
||||
if new_progress is not None:
|
||||
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
|
||||
|
||||
if status == TaskStatus.COMPLETED:
|
||||
logging.debug("[DEBUG] Task completed successfully")
|
||||
self.final_response = response_obj
|
||||
if self.progress_extractor:
|
||||
progress.update(100)
|
||||
return self.final_response
|
||||
elif status == TaskStatus.FAILED:
|
||||
message = f"Task failed: {json.dumps(resp)}"
|
||||
logging.error(f"[DEBUG] {message}")
|
||||
raise Exception(message)
|
||||
else:
|
||||
logging.debug("[DEBUG] Task still pending, continuing to poll...")
|
||||
|
||||
# Wait before polling again
|
||||
logging.debug(
|
||||
f"[DEBUG] Waiting {self.poll_interval} seconds before next poll"
|
||||
)
|
||||
time.sleep(self.poll_interval)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"[DEBUG] Polling error: {str(e)}")
|
||||
raise Exception(f"Error while polling: {str(e)}")
|
253
comfy_api_nodes/apis/luma_api.py
Normal file
253
comfy_api_nodes/apis/luma_api.py
Normal file
@ -0,0 +1,253 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, confloat
|
||||
|
||||
|
||||
|
||||
class LumaIO:
|
||||
LUMA_REF = "LUMA_REF"
|
||||
LUMA_CONCEPTS = "LUMA_CONCEPTS"
|
||||
|
||||
|
||||
class LumaReference:
|
||||
def __init__(self, image: torch.Tensor, weight: float):
|
||||
self.image = image
|
||||
self.weight = weight
|
||||
|
||||
def create_api_model(self, download_url: str):
|
||||
return LumaImageRef(url=download_url, weight=self.weight)
|
||||
|
||||
class LumaReferenceChain:
|
||||
def __init__(self, first_ref: LumaReference=None):
|
||||
self.refs: list[LumaReference] = []
|
||||
if first_ref:
|
||||
self.refs.append(first_ref)
|
||||
|
||||
def add(self, luma_ref: LumaReference=None):
|
||||
self.refs.append(luma_ref)
|
||||
|
||||
def create_api_model(self, download_urls: list[str], max_refs=4):
|
||||
if len(self.refs) == 0:
|
||||
return None
|
||||
api_refs: list[LumaImageRef] = []
|
||||
for ref, url in zip(self.refs, download_urls):
|
||||
api_ref = LumaImageRef(url=url, weight=ref.weight)
|
||||
api_refs.append(api_ref)
|
||||
return api_refs
|
||||
|
||||
def clone(self):
|
||||
c = LumaReferenceChain()
|
||||
for ref in self.refs:
|
||||
c.add(ref)
|
||||
return c
|
||||
|
||||
|
||||
class LumaConcept:
|
||||
def __init__(self, key: str):
|
||||
self.key = key
|
||||
|
||||
|
||||
class LumaConceptChain:
|
||||
def __init__(self, str_list: list[str] = None):
|
||||
self.concepts: list[LumaConcept] = []
|
||||
if str_list is not None:
|
||||
for c in str_list:
|
||||
if c != "None":
|
||||
self.add(LumaConcept(key=c))
|
||||
|
||||
def add(self, concept: LumaConcept):
|
||||
self.concepts.append(concept)
|
||||
|
||||
def create_api_model(self):
|
||||
if len(self.concepts) == 0:
|
||||
return None
|
||||
api_concepts: list[LumaConceptObject] = []
|
||||
for concept in self.concepts:
|
||||
if concept.key == "None":
|
||||
continue
|
||||
api_concepts.append(LumaConceptObject(key=concept.key))
|
||||
if len(api_concepts) == 0:
|
||||
return None
|
||||
return api_concepts
|
||||
|
||||
def clone(self):
|
||||
c = LumaConceptChain()
|
||||
for concept in self.concepts:
|
||||
c.add(concept)
|
||||
return c
|
||||
|
||||
def clone_and_merge(self, other: LumaConceptChain):
|
||||
c = self.clone()
|
||||
for concept in other.concepts:
|
||||
c.add(concept)
|
||||
return c
|
||||
|
||||
|
||||
def get_luma_concepts(include_none=False):
|
||||
concepts = []
|
||||
if include_none:
|
||||
concepts.append("None")
|
||||
return concepts + [
|
||||
"truck_left",
|
||||
"pan_right",
|
||||
"pedestal_down",
|
||||
"low_angle",
|
||||
"pedestal_up",
|
||||
"selfie",
|
||||
"pan_left",
|
||||
"roll_right",
|
||||
"zoom_in",
|
||||
"over_the_shoulder",
|
||||
"orbit_right",
|
||||
"orbit_left",
|
||||
"static",
|
||||
"tiny_planet",
|
||||
"high_angle",
|
||||
"bolt_cam",
|
||||
"dolly_zoom",
|
||||
"overhead",
|
||||
"zoom_out",
|
||||
"handheld",
|
||||
"roll_left",
|
||||
"pov",
|
||||
"aerial_drone",
|
||||
"push_in",
|
||||
"crane_down",
|
||||
"truck_right",
|
||||
"tilt_down",
|
||||
"elevator_doors",
|
||||
"tilt_up",
|
||||
"ground_level",
|
||||
"pull_out",
|
||||
"aerial",
|
||||
"crane_up",
|
||||
"eye_level"
|
||||
]
|
||||
|
||||
|
||||
class LumaImageModel(str, Enum):
|
||||
photon_1 = "photon-1"
|
||||
photon_flash_1 = "photon-flash-1"
|
||||
|
||||
|
||||
class LumaVideoModel(str, Enum):
|
||||
ray_2 = "ray-2"
|
||||
ray_flash_2 = "ray-flash-2"
|
||||
ray_1_6 = "ray-1-6"
|
||||
|
||||
|
||||
class LumaAspectRatio(str, Enum):
|
||||
ratio_1_1 = "1:1"
|
||||
ratio_16_9 = "16:9"
|
||||
ratio_9_16 = "9:16"
|
||||
ratio_4_3 = "4:3"
|
||||
ratio_3_4 = "3:4"
|
||||
ratio_21_9 = "21:9"
|
||||
ratio_9_21 = "9:21"
|
||||
|
||||
|
||||
class LumaVideoOutputResolution(str, Enum):
|
||||
res_540p = "540p"
|
||||
res_720p = "720p"
|
||||
res_1080p = "1080p"
|
||||
res_4k = "4k"
|
||||
|
||||
|
||||
class LumaVideoModelOutputDuration(str, Enum):
|
||||
dur_5s = "5s"
|
||||
dur_9s = "9s"
|
||||
|
||||
|
||||
class LumaGenerationType(str, Enum):
|
||||
video = 'video'
|
||||
image = 'image'
|
||||
|
||||
|
||||
class LumaState(str, Enum):
|
||||
queued = "queued"
|
||||
dreaming = "dreaming"
|
||||
completed = "completed"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
class LumaAssets(BaseModel):
|
||||
video: Optional[str] = Field(None, description='The URL of the video')
|
||||
image: Optional[str] = Field(None, description='The URL of the image')
|
||||
progress_video: Optional[str] = Field(None, description='The URL of the progress video')
|
||||
|
||||
|
||||
class LumaImageRef(BaseModel):
|
||||
'''Used for image gen'''
|
||||
url: str = Field(..., description='The URL of the image reference')
|
||||
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
|
||||
|
||||
|
||||
class LumaImageReference(BaseModel):
|
||||
'''Used for video gen'''
|
||||
type: Optional[str] = Field('image', description='Input type, defaults to image')
|
||||
url: str = Field(..., description='The URL of the image')
|
||||
|
||||
|
||||
class LumaModifyImageRef(BaseModel):
|
||||
url: str = Field(..., description='The URL of the image reference')
|
||||
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
|
||||
|
||||
|
||||
class LumaCharacterRef(BaseModel):
|
||||
identity0: LumaImageIdentity = Field(..., description='The image identity object')
|
||||
|
||||
|
||||
class LumaImageIdentity(BaseModel):
|
||||
images: list[str] = Field(..., description='The URLs of the image identity')
|
||||
|
||||
|
||||
class LumaGenerationReference(BaseModel):
|
||||
type: str = Field('generation', description='Input type, defaults to generation')
|
||||
id: str = Field(..., description='The ID of the generation')
|
||||
|
||||
|
||||
class LumaKeyframes(BaseModel):
|
||||
frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='')
|
||||
frame1: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='')
|
||||
|
||||
|
||||
class LumaConceptObject(BaseModel):
|
||||
key: str = Field(..., description='Camera Concept name')
|
||||
|
||||
|
||||
class LumaImageGenerationRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The prompt of the generation')
|
||||
model: LumaImageModel = Field(LumaImageModel.photon_1, description='The image model used for the generation')
|
||||
aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9, description='The aspect ratio of the generation')
|
||||
image_ref: Optional[list[LumaImageRef]] = Field(None, description='List of image reference objects')
|
||||
style_ref: Optional[list[LumaImageRef]] = Field(None, description='List of style reference objects')
|
||||
character_ref: Optional[LumaCharacterRef] = Field(None, description='The image identity object')
|
||||
modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description='The modify image reference object')
|
||||
|
||||
|
||||
class LumaGenerationRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The prompt of the generation')
|
||||
model: LumaVideoModel = Field(LumaVideoModel.ray_2, description='The video model used for the generation')
|
||||
duration: Optional[LumaVideoModelOutputDuration] = Field(None, description='The duration of the generation')
|
||||
aspect_ratio: Optional[LumaAspectRatio] = Field(None, description='The aspect ratio of the generation')
|
||||
resolution: Optional[LumaVideoOutputResolution] = Field(None, description='The resolution of the generation')
|
||||
loop: Optional[bool] = Field(None, description='Whether to loop the video')
|
||||
keyframes: Optional[LumaKeyframes] = Field(None, description='The keyframes of the generation')
|
||||
concepts: Optional[list[LumaConceptObject]] = Field(None, description='Camera Concepts to apply to generation')
|
||||
|
||||
|
||||
class LumaGeneration(BaseModel):
|
||||
id: str = Field(..., description='The ID of the generation')
|
||||
generation_type: LumaGenerationType = Field(..., description='Generation type, image or video')
|
||||
state: LumaState = Field(..., description='The state of the generation')
|
||||
failure_reason: Optional[str] = Field(None, description='The reason for the state of the generation')
|
||||
created_at: str = Field(..., description='The date and time when the generation was created')
|
||||
assets: Optional[LumaAssets] = Field(None, description='The assets of the generation')
|
||||
model: str = Field(..., description='The model used for the generation')
|
||||
request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation")
|
146
comfy_api_nodes/apis/pixverse_api.py
Normal file
146
comfy_api_nodes/apis/pixverse_api.py
Normal file
@ -0,0 +1,146 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
pixverse_templates = {
|
||||
"Microwave": 324641385496960,
|
||||
"Suit Swagger": 328545151283968,
|
||||
"Anything, Robot": 313358700761536,
|
||||
"Subject 3 Fever": 327828816843648,
|
||||
"kiss kiss": 315446315336768,
|
||||
}
|
||||
|
||||
|
||||
class PixverseIO:
|
||||
TEMPLATE = "PIXVERSE_TEMPLATE"
|
||||
|
||||
|
||||
class PixverseStatus(int, Enum):
|
||||
successful = 1
|
||||
generating = 5
|
||||
deleted = 6
|
||||
contents_moderation = 7
|
||||
failed = 8
|
||||
|
||||
|
||||
class PixverseAspectRatio(str, Enum):
|
||||
ratio_16_9 = "16:9"
|
||||
ratio_4_3 = "4:3"
|
||||
ratio_1_1 = "1:1"
|
||||
ratio_3_4 = "3:4"
|
||||
ratio_9_16 = "9:16"
|
||||
|
||||
|
||||
class PixverseQuality(str, Enum):
|
||||
res_360p = "360p"
|
||||
res_540p = "540p"
|
||||
res_720p = "720p"
|
||||
res_1080p = "1080p"
|
||||
|
||||
|
||||
class PixverseDuration(int, Enum):
|
||||
dur_5 = 5
|
||||
dur_8 = 8
|
||||
|
||||
|
||||
class PixverseMotionMode(str, Enum):
|
||||
normal = "normal"
|
||||
fast = "fast"
|
||||
|
||||
|
||||
class PixverseStyle(str, Enum):
|
||||
anime = "anime"
|
||||
animation_3d = "3d_animation"
|
||||
clay = "clay"
|
||||
comic = "comic"
|
||||
cyberpunk = "cyberpunk"
|
||||
|
||||
|
||||
# NOTE: forgoing descriptions for now in return for dev speed
|
||||
class PixverseTextVideoRequest(BaseModel):
|
||||
aspect_ratio: PixverseAspectRatio = Field(...)
|
||||
quality: PixverseQuality = Field(...)
|
||||
duration: PixverseDuration = Field(...)
|
||||
model: Optional[str] = Field("v3.5")
|
||||
motion_mode: Optional[PixverseMotionMode] = Field(PixverseMotionMode.normal)
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
style: Optional[str] = Field(None)
|
||||
template_id: Optional[int] = Field(None)
|
||||
water_mark: Optional[bool] = Field(None)
|
||||
|
||||
|
||||
class PixverseImageVideoRequest(BaseModel):
|
||||
quality: PixverseQuality = Field(...)
|
||||
duration: PixverseDuration = Field(...)
|
||||
img_id: int = Field(...)
|
||||
model: Optional[str] = Field("v3.5")
|
||||
motion_mode: Optional[PixverseMotionMode] = Field(PixverseMotionMode.normal)
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
style: Optional[str] = Field(None)
|
||||
template_id: Optional[int] = Field(None)
|
||||
water_mark: Optional[bool] = Field(None)
|
||||
|
||||
|
||||
class PixverseTransitionVideoRequest(BaseModel):
|
||||
quality: PixverseQuality = Field(...)
|
||||
duration: PixverseDuration = Field(...)
|
||||
first_frame_img: int = Field(...)
|
||||
last_frame_img: int = Field(...)
|
||||
model: Optional[str] = Field("v3.5")
|
||||
motion_mode: Optional[PixverseMotionMode] = Field(PixverseMotionMode.normal)
|
||||
prompt: str = Field(...)
|
||||
# negative_prompt: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
# style: Optional[str] = Field(None)
|
||||
# template_id: Optional[int] = Field(None)
|
||||
# water_mark: Optional[bool] = Field(None)
|
||||
|
||||
|
||||
class PixverseImageUploadResponse(BaseModel):
|
||||
ErrCode: Optional[int] = None
|
||||
ErrMsg: Optional[str] = None
|
||||
Resp: Optional[PixverseImgIdResponseObject] = Field(None, alias='Resp')
|
||||
|
||||
|
||||
class PixverseImgIdResponseObject(BaseModel):
|
||||
img_id: Optional[int] = None
|
||||
|
||||
|
||||
class PixverseVideoResponse(BaseModel):
|
||||
ErrCode: Optional[int] = Field(None)
|
||||
ErrMsg: Optional[str] = Field(None)
|
||||
Resp: Optional[PixverseVideoIdResponseObject] = Field(None)
|
||||
|
||||
|
||||
class PixverseVideoIdResponseObject(BaseModel):
|
||||
video_id: int = Field(..., description='Video_id')
|
||||
|
||||
|
||||
class PixverseGenerationStatusResponse(BaseModel):
|
||||
ErrCode: Optional[int] = Field(None)
|
||||
ErrMsg: Optional[str] = Field(None)
|
||||
Resp: Optional[PixverseGenerationStatusResponseObject] = Field(None)
|
||||
|
||||
|
||||
class PixverseGenerationStatusResponseObject(BaseModel):
|
||||
create_time: Optional[str] = Field(None)
|
||||
id: Optional[int] = Field(None)
|
||||
modify_time: Optional[str] = Field(None)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
outputHeight: Optional[int] = Field(None)
|
||||
outputWidth: Optional[int] = Field(None)
|
||||
prompt: Optional[str] = Field(None)
|
||||
resolution_ratio: Optional[int] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
size: Optional[int] = Field(None)
|
||||
status: Optional[int] = Field(None)
|
||||
style: Optional[str] = Field(None)
|
||||
url: Optional[str] = Field(None)
|
262
comfy_api_nodes/apis/recraft_api.py
Normal file
262
comfy_api_nodes/apis/recraft_api.py
Normal file
@ -0,0 +1,262 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, conint, confloat
|
||||
|
||||
|
||||
class RecraftColor:
|
||||
def __init__(self, r: int, g: int, b: int):
|
||||
self.color = [r, g, b]
|
||||
|
||||
def create_api_model(self):
|
||||
return RecraftColorObject(rgb=self.color)
|
||||
|
||||
|
||||
class RecraftColorChain:
|
||||
def __init__(self):
|
||||
self.colors: list[RecraftColor] = []
|
||||
|
||||
def get_first(self):
|
||||
if len(self.colors) > 0:
|
||||
return self.colors[0]
|
||||
return None
|
||||
|
||||
def add(self, color: RecraftColor):
|
||||
self.colors.append(color)
|
||||
|
||||
def create_api_model(self):
|
||||
if not self.colors:
|
||||
return None
|
||||
colors_api = [x.create_api_model() for x in self.colors]
|
||||
return colors_api
|
||||
|
||||
def clone(self):
|
||||
c = RecraftColorChain()
|
||||
for color in self.colors:
|
||||
c.add(color)
|
||||
return c
|
||||
|
||||
def clone_and_merge(self, other: RecraftColorChain):
|
||||
c = self.clone()
|
||||
for color in other.colors:
|
||||
c.add(color)
|
||||
return c
|
||||
|
||||
|
||||
class RecraftControls:
|
||||
def __init__(self, colors: RecraftColorChain=None, background_color: RecraftColorChain=None,
|
||||
artistic_level: int=None, no_text: bool=None):
|
||||
self.colors = colors
|
||||
self.background_color = background_color
|
||||
self.artistic_level = artistic_level
|
||||
self.no_text = no_text
|
||||
|
||||
def create_api_model(self):
|
||||
if self.colors is None and self.background_color is None and self.artistic_level is None and self.no_text is None:
|
||||
return None
|
||||
colors_api = None
|
||||
background_color_api = None
|
||||
if self.colors:
|
||||
colors_api = self.colors.create_api_model()
|
||||
if self.background_color:
|
||||
first_background = self.background_color.get_first()
|
||||
background_color_api = first_background.create_api_model() if first_background else None
|
||||
|
||||
return RecraftControlsObject(colors=colors_api, background_color=background_color_api,
|
||||
artistic_level=self.artistic_level, no_text=self.no_text)
|
||||
|
||||
|
||||
class RecraftStyle:
|
||||
def __init__(self, style: str=None, substyle: str=None, style_id: str=None):
|
||||
self.style = style
|
||||
if substyle == "None":
|
||||
substyle = None
|
||||
self.substyle = substyle
|
||||
self.style_id = style_id
|
||||
|
||||
|
||||
class RecraftIO:
|
||||
STYLEV3 = "RECRAFT_V3_STYLE"
|
||||
COLOR = "RECRAFT_COLOR"
|
||||
CONTROLS = "RECRAFT_CONTROLS"
|
||||
|
||||
|
||||
class RecraftStyleV3(str, Enum):
|
||||
#any = 'any' NOTE: this does not work for some reason... why?
|
||||
realistic_image = 'realistic_image'
|
||||
digital_illustration = 'digital_illustration'
|
||||
vector_illustration = 'vector_illustration'
|
||||
logo_raster = 'logo_raster'
|
||||
|
||||
|
||||
def get_v3_substyles(style_v3: str, include_none=True) -> list[str]:
|
||||
substyles: list[str] = []
|
||||
if include_none:
|
||||
substyles.append("None")
|
||||
return substyles + dict_recraft_substyles_v3.get(style_v3, [])
|
||||
|
||||
|
||||
dict_recraft_substyles_v3 = {
|
||||
RecraftStyleV3.realistic_image: [
|
||||
"b_and_w",
|
||||
"enterprise",
|
||||
"evening_light",
|
||||
"faded_nostalgia",
|
||||
"forest_life",
|
||||
"hard_flash",
|
||||
"hdr",
|
||||
"motion_blur",
|
||||
"mystic_naturalism",
|
||||
"natural_light",
|
||||
"natural_tones",
|
||||
"organic_calm",
|
||||
"real_life_glow",
|
||||
"retro_realism",
|
||||
"retro_snapshot",
|
||||
"studio_portrait",
|
||||
"urban_drama",
|
||||
"village_realism",
|
||||
"warm_folk"
|
||||
],
|
||||
RecraftStyleV3.digital_illustration: [
|
||||
"2d_art_poster",
|
||||
"2d_art_poster_2",
|
||||
"antiquarian",
|
||||
"bold_fantasy",
|
||||
"child_book",
|
||||
"child_books",
|
||||
"cover",
|
||||
"crosshatch",
|
||||
"digital_engraving",
|
||||
"engraving_color",
|
||||
"expressionism",
|
||||
"freehand_details",
|
||||
"grain",
|
||||
"grain_20",
|
||||
"graphic_intensity",
|
||||
"hand_drawn",
|
||||
"hand_drawn_outline",
|
||||
"handmade_3d",
|
||||
"hard_comics",
|
||||
"infantile_sketch",
|
||||
"long_shadow",
|
||||
"modern_folk",
|
||||
"multicolor",
|
||||
"neon_calm",
|
||||
"noir",
|
||||
"nostalgic_pastel",
|
||||
"outline_details",
|
||||
"pastel_gradient",
|
||||
"pastel_sketch",
|
||||
"pixel_art",
|
||||
"plastic",
|
||||
"pop_art",
|
||||
"pop_renaissance",
|
||||
"seamless",
|
||||
"street_art",
|
||||
"tablet_sketch",
|
||||
"urban_glow",
|
||||
"urban_sketching",
|
||||
"vanilla_dreams",
|
||||
"young_adult_book",
|
||||
"young_adult_book_2"
|
||||
],
|
||||
RecraftStyleV3.vector_illustration: [
|
||||
"bold_stroke",
|
||||
"chemistry",
|
||||
"colored_stencil",
|
||||
"contour_pop_art",
|
||||
"cosmics",
|
||||
"cutout",
|
||||
"depressive",
|
||||
"editorial",
|
||||
"emotional_flat",
|
||||
"engraving",
|
||||
"infographical",
|
||||
"line_art",
|
||||
"line_circuit",
|
||||
"linocut",
|
||||
"marker_outline",
|
||||
"mosaic",
|
||||
"naivector",
|
||||
"roundish_flat",
|
||||
"seamless",
|
||||
"segmented_colors",
|
||||
"sharp_contrast",
|
||||
"thin",
|
||||
"vector_photo",
|
||||
"vivid_shapes"
|
||||
],
|
||||
RecraftStyleV3.logo_raster: [
|
||||
"emblem_graffiti",
|
||||
"emblem_pop_art",
|
||||
"emblem_punk",
|
||||
"emblem_stamp",
|
||||
"emblem_vintage"
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class RecraftModel(str, Enum):
|
||||
recraftv3 = 'recraftv3'
|
||||
recraftv2 = 'recraftv2'
|
||||
|
||||
|
||||
class RecraftImageSize(str, Enum):
|
||||
res_1024x1024 = '1024x1024'
|
||||
res_1365x1024 = '1365x1024'
|
||||
res_1024x1365 = '1024x1365'
|
||||
res_1536x1024 = '1536x1024'
|
||||
res_1024x1536 = '1024x1536'
|
||||
res_1820x1024 = '1820x1024'
|
||||
res_1024x1820 = '1024x1820'
|
||||
res_1024x2048 = '1024x2048'
|
||||
res_2048x1024 = '2048x1024'
|
||||
res_1434x1024 = '1434x1024'
|
||||
res_1024x1434 = '1024x1434'
|
||||
res_1024x1280 = '1024x1280'
|
||||
res_1280x1024 = '1280x1024'
|
||||
res_1024x1707 = '1024x1707'
|
||||
res_1707x1024 = '1707x1024'
|
||||
|
||||
|
||||
class RecraftColorObject(BaseModel):
|
||||
rgb: list[int] = Field(..., description='An array of 3 integer values in range of 0...255 defining RGB Color Model')
|
||||
|
||||
|
||||
class RecraftControlsObject(BaseModel):
|
||||
colors: Optional[list[RecraftColorObject]] = Field(None, description='An array of preferable colors')
|
||||
background_color: Optional[RecraftColorObject] = Field(None, description='Use given color as a desired background color')
|
||||
no_text: Optional[bool] = Field(None, description='Do not embed text layouts')
|
||||
artistic_level: Optional[conint(ge=0, le=5)] = Field(None, description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. Dynamic and eccentric levels introduce movement and creativity. The value should be in range [0..5].')
|
||||
|
||||
|
||||
class RecraftImageGenerationRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt describing the image to generate')
|
||||
size: Optional[RecraftImageSize] = Field(None, description='The size of the generated image (e.g., "1024x1024")')
|
||||
n: conint(ge=1, le=6) = Field(..., description='The number of images to generate')
|
||||
negative_prompt: Optional[str] = Field(None, description='A text description of undesired elements on an image')
|
||||
model: Optional[RecraftModel] = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")')
|
||||
style: Optional[str] = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
|
||||
substyle: Optional[str] = Field(None, description='The substyle to apply to the generated image, depending on the style input')
|
||||
controls: Optional[RecraftControlsObject] = Field(None, description='A set of custom parameters to tweak generation process')
|
||||
style_id: Optional[str] = Field(None, description='Use a previously uploaded style as a reference; UUID')
|
||||
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
|
||||
random_seed: Optional[int] = Field(None, description="Seed for video generation")
|
||||
# text_layout
|
||||
|
||||
|
||||
class RecraftReturnedObject(BaseModel):
|
||||
image_id: str = Field(..., description='Unique identifier for the generated image')
|
||||
url: str = Field(..., description='URL to access the generated image')
|
||||
|
||||
|
||||
class RecraftImageGenerationResponse(BaseModel):
|
||||
created: int = Field(..., description='Unix timestamp when the generation was created')
|
||||
credits: int = Field(..., description='Number of credits used for the generation')
|
||||
data: Optional[list[RecraftReturnedObject]] = Field(None, description='Array of generated image information')
|
||||
image: Optional[RecraftReturnedObject] = Field(None, description='Single generated image')
|
127
comfy_api_nodes/apis/stability_api.py
Normal file
127
comfy_api_nodes/apis/stability_api.py
Normal file
@ -0,0 +1,127 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, confloat
|
||||
|
||||
|
||||
class StabilityFormat(str, Enum):
|
||||
png = 'png'
|
||||
jpeg = 'jpeg'
|
||||
webp = 'webp'
|
||||
|
||||
|
||||
class StabilityAspectRatio(str, Enum):
|
||||
ratio_1_1 = "1:1"
|
||||
ratio_16_9 = "16:9"
|
||||
ratio_9_16 = "9:16"
|
||||
ratio_3_2 = "3:2"
|
||||
ratio_2_3 = "2:3"
|
||||
ratio_5_4 = "5:4"
|
||||
ratio_4_5 = "4:5"
|
||||
ratio_21_9 = "21:9"
|
||||
ratio_9_21 = "9:21"
|
||||
|
||||
|
||||
def get_stability_style_presets(include_none=True):
|
||||
presets = []
|
||||
if include_none:
|
||||
presets.append("None")
|
||||
return presets + [x.value for x in StabilityStylePreset]
|
||||
|
||||
|
||||
class StabilityStylePreset(str, Enum):
|
||||
_3d_model = "3d-model"
|
||||
analog_film = "analog-film"
|
||||
anime = "anime"
|
||||
cinematic = "cinematic"
|
||||
comic_book = "comic-book"
|
||||
digital_art = "digital-art"
|
||||
enhance = "enhance"
|
||||
fantasy_art = "fantasy-art"
|
||||
isometric = "isometric"
|
||||
line_art = "line-art"
|
||||
low_poly = "low-poly"
|
||||
modeling_compound = "modeling-compound"
|
||||
neon_punk = "neon-punk"
|
||||
origami = "origami"
|
||||
photographic = "photographic"
|
||||
pixel_art = "pixel-art"
|
||||
tile_texture = "tile-texture"
|
||||
|
||||
|
||||
class Stability_SD3_5_Model(str, Enum):
|
||||
sd3_5_large = "sd3.5-large"
|
||||
# sd3_5_large_turbo = "sd3.5-large-turbo"
|
||||
sd3_5_medium = "sd3.5-medium"
|
||||
|
||||
|
||||
class Stability_SD3_5_GenerationMode(str, Enum):
|
||||
text_to_image = "text-to-image"
|
||||
image_to_image = "image-to-image"
|
||||
|
||||
|
||||
class StabilityStable3_5Request(BaseModel):
|
||||
model: str = Field(...)
|
||||
mode: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
aspect_ratio: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||
image: Optional[str] = Field(None)
|
||||
style_preset: Optional[str] = Field(None)
|
||||
cfg_scale: float = Field(...)
|
||||
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)
|
||||
|
||||
|
||||
class StabilityUpscaleConservativeRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||
image: Optional[str] = Field(None)
|
||||
creativity: Optional[confloat(ge=0.2, le=0.5)] = Field(None)
|
||||
|
||||
|
||||
class StabilityUpscaleCreativeRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||
image: Optional[str] = Field(None)
|
||||
creativity: Optional[confloat(ge=0.1, le=0.5)] = Field(None)
|
||||
style_preset: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class StabilityStableUltraRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
aspect_ratio: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
output_format: Optional[str] = Field(StabilityFormat.png.value)
|
||||
image: Optional[str] = Field(None)
|
||||
style_preset: Optional[str] = Field(None)
|
||||
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)
|
||||
|
||||
|
||||
class StabilityStableUltraResponse(BaseModel):
|
||||
image: Optional[str] = Field(None)
|
||||
finish_reason: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class StabilityResultsGetResponse(BaseModel):
|
||||
image: Optional[str] = Field(None)
|
||||
finish_reason: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
id: Optional[str] = Field(None)
|
||||
name: Optional[str] = Field(None)
|
||||
errors: Optional[list[str]] = Field(None)
|
||||
status: Optional[str] = Field(None)
|
||||
result: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class StabilityAsyncResponse(BaseModel):
|
||||
id: Optional[str] = Field(None)
|
10
comfy_api_nodes/canary.py
Normal file
10
comfy_api_nodes/canary.py
Normal file
@ -0,0 +1,10 @@
|
||||
import av
|
||||
|
||||
ver = av.__version__.split(".")
|
||||
if int(ver[0]) < 14:
|
||||
raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.")
|
||||
|
||||
if int(ver[0]) == 14 and int(ver[1]) < 2:
|
||||
raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.")
|
||||
|
||||
NODE_CLASS_MAPPINGS = {}
|
116
comfy_api_nodes/mapper_utils.py
Normal file
116
comfy_api_nodes/mapper_utils.py
Normal file
@ -0,0 +1,116 @@
|
||||
from enum import Enum
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic import BaseModel
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from comfy.comfy_types.node_typing import IO, InputTypeOptions
|
||||
|
||||
NodeInput = tuple[IO, InputTypeOptions]
|
||||
|
||||
|
||||
def _create_base_config(field_info: FieldInfo) -> InputTypeOptions:
|
||||
config = {}
|
||||
if hasattr(field_info, "default") and field_info.default is not PydanticUndefined:
|
||||
config["default"] = field_info.default
|
||||
if hasattr(field_info, "description") and field_info.description is not None:
|
||||
config["tooltip"] = field_info.description
|
||||
return config
|
||||
|
||||
|
||||
def _get_number_constraints_config(field_info: FieldInfo) -> dict:
|
||||
config = {}
|
||||
if hasattr(field_info, "metadata"):
|
||||
metadata = field_info.metadata
|
||||
for constraint in metadata:
|
||||
if hasattr(constraint, "ge"):
|
||||
config["min"] = constraint.ge
|
||||
if hasattr(constraint, "le"):
|
||||
config["max"] = constraint.le
|
||||
if hasattr(constraint, "multiple_of"):
|
||||
config["step"] = constraint.multiple_of
|
||||
return config
|
||||
|
||||
|
||||
def _model_field_to_image_input(field_info: FieldInfo, **kwargs) -> NodeInput:
|
||||
return IO.IMAGE, {
|
||||
**_create_base_config(field_info),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
def _model_field_to_string_input(field_info: FieldInfo, **kwargs) -> NodeInput:
|
||||
return IO.STRING, {
|
||||
**_create_base_config(field_info),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
def _model_field_to_float_input(field_info: FieldInfo, **kwargs) -> NodeInput:
|
||||
return IO.FLOAT, {
|
||||
**_create_base_config(field_info),
|
||||
**_get_number_constraints_config(field_info),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
def _model_field_to_int_input(field_info: FieldInfo, **kwargs) -> NodeInput:
|
||||
return IO.INT, {
|
||||
**_create_base_config(field_info),
|
||||
**_get_number_constraints_config(field_info),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
def _model_field_to_combo_input(
|
||||
field_info: FieldInfo, enum_type: type[Enum] = None, **kwargs
|
||||
) -> NodeInput:
|
||||
combo_config = {}
|
||||
if enum_type is not None:
|
||||
combo_config["options"] = [option.value for option in enum_type]
|
||||
combo_config = {
|
||||
**combo_config,
|
||||
**_create_base_config(field_info),
|
||||
**kwargs,
|
||||
}
|
||||
return IO.COMBO, combo_config
|
||||
|
||||
|
||||
def model_field_to_node_input(
|
||||
input_type: IO, base_model: type[BaseModel], field_name: str, **kwargs
|
||||
) -> NodeInput:
|
||||
"""
|
||||
Maps a field from a Pydantic model to a Comfy node input.
|
||||
|
||||
Args:
|
||||
input_type: The type of the input.
|
||||
base_model: The Pydantic model to map the field from.
|
||||
field_name: The name of the field to map.
|
||||
**kwargs: Additional key/values to include in the input options.
|
||||
|
||||
Note:
|
||||
For combo inputs, pass an `Enum` to the `enum_type` keyword argument to populate the options automatically.
|
||||
|
||||
Example:
|
||||
>>> model_field_to_node_input(IO.STRING, MyModel, "my_field", multiline=True)
|
||||
>>> model_field_to_node_input(IO.COMBO, MyModel, "my_field", enum_type=MyEnum)
|
||||
>>> model_field_to_node_input(IO.FLOAT, MyModel, "my_field", slider=True)
|
||||
"""
|
||||
field_info: FieldInfo = base_model.model_fields[field_name]
|
||||
result: NodeInput
|
||||
|
||||
if input_type == IO.IMAGE:
|
||||
result = _model_field_to_image_input(field_info, **kwargs)
|
||||
elif input_type == IO.STRING:
|
||||
result = _model_field_to_string_input(field_info, **kwargs)
|
||||
elif input_type == IO.FLOAT:
|
||||
result = _model_field_to_float_input(field_info, **kwargs)
|
||||
elif input_type == IO.INT:
|
||||
result = _model_field_to_int_input(field_info, **kwargs)
|
||||
elif input_type == IO.COMBO:
|
||||
result = _model_field_to_combo_input(field_info, **kwargs)
|
||||
else:
|
||||
message = f"Invalid input type: {input_type}"
|
||||
raise ValueError(message)
|
||||
|
||||
return result
|
906
comfy_api_nodes/nodes_bfl.py
Normal file
906
comfy_api_nodes/nodes_bfl.py
Normal file
@ -0,0 +1,906 @@
|
||||
import io
|
||||
from inspect import cleandoc
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||
from comfy_api_nodes.apis.bfl_api import (
|
||||
BFLStatus,
|
||||
BFLFluxExpandImageRequest,
|
||||
BFLFluxFillImageRequest,
|
||||
BFLFluxCannyImageRequest,
|
||||
BFLFluxDepthImageRequest,
|
||||
BFLFluxProGenerateRequest,
|
||||
BFLFluxProUltraGenerateRequest,
|
||||
BFLFluxProGenerateResponse,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
downscale_image_tensor,
|
||||
validate_aspect_ratio,
|
||||
process_image_response,
|
||||
resize_mask_to_image,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import requests
|
||||
import torch
|
||||
import base64
|
||||
import time
|
||||
|
||||
|
||||
def convert_mask_to_image(mask: torch.Tensor):
|
||||
"""
|
||||
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
|
||||
"""
|
||||
mask = mask.unsqueeze(-1)
|
||||
mask = torch.cat([mask]*3, dim=-1)
|
||||
return mask
|
||||
|
||||
|
||||
def handle_bfl_synchronous_operation(
|
||||
operation: SynchronousOperation, timeout_bfl_calls=360
|
||||
):
|
||||
response_api: BFLFluxProGenerateResponse = operation.execute()
|
||||
return _poll_until_generated(
|
||||
response_api.polling_url, timeout=timeout_bfl_calls
|
||||
)
|
||||
|
||||
def _poll_until_generated(polling_url: str, timeout=360):
|
||||
# used bfl-comfy-nodes to verify code implementation:
|
||||
# https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main
|
||||
start_time = time.time()
|
||||
retries_404 = 0
|
||||
max_retries_404 = 5
|
||||
retry_404_seconds = 2
|
||||
retry_202_seconds = 2
|
||||
retry_pending_seconds = 1
|
||||
request = requests.Request(method=HttpMethod.GET, url=polling_url)
|
||||
# NOTE: should True loop be replaced with checking if workflow has been interrupted?
|
||||
while True:
|
||||
response = requests.Session().send(request.prepare())
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
if result["status"] == BFLStatus.ready:
|
||||
img_url = result["result"]["sample"]
|
||||
img_response = requests.get(img_url)
|
||||
return process_image_response(img_response)
|
||||
elif result["status"] in [
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
]:
|
||||
status = result["status"]
|
||||
raise Exception(
|
||||
f"BFL API did not return an image due to: {status}."
|
||||
)
|
||||
elif result["status"] == BFLStatus.error:
|
||||
raise Exception(f"BFL API encountered an error: {result}.")
|
||||
elif result["status"] == BFLStatus.pending:
|
||||
time.sleep(retry_pending_seconds)
|
||||
continue
|
||||
elif response.status_code == 404:
|
||||
if retries_404 < max_retries_404:
|
||||
retries_404 += 1
|
||||
time.sleep(retry_404_seconds)
|
||||
continue
|
||||
raise Exception(
|
||||
f"BFL API could not find task after {max_retries_404} tries."
|
||||
)
|
||||
elif response.status_code == 202:
|
||||
time.sleep(retry_202_seconds)
|
||||
elif time.time() - start_time > timeout:
|
||||
raise Exception(
|
||||
f"BFL API experienced a timeout; could not return request under {timeout} seconds."
|
||||
)
|
||||
else:
|
||||
raise Exception(f"BFL API encountered an error: {response.json()}")
|
||||
|
||||
def convert_image_to_base64(image: torch.Tensor):
|
||||
scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048)
|
||||
# remove batch dimension if present
|
||||
if len(scaled_image.shape) > 3:
|
||||
scaled_image = scaled_image[0]
|
||||
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
img_byte_arr = io.BytesIO()
|
||||
img.save(img_byte_arr, format="PNG")
|
||||
return base64.b64encode(img_byte_arr.getvalue()).decode()
|
||||
|
||||
|
||||
class FluxProUltraImageNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
|
||||
"""
|
||||
|
||||
MINIMUM_RATIO = 1 / 4
|
||||
MAXIMUM_RATIO = 4 / 1
|
||||
MINIMUM_RATIO_STR = "1:4"
|
||||
MAXIMUM_RATIO_STR = "4:1"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"prompt_upsampling": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
"aspect_ratio": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "16:9",
|
||||
"tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.",
|
||||
},
|
||||
),
|
||||
"raw": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "When True, generate less processed, more natural-looking images.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image_prompt": (IO.IMAGE,),
|
||||
"image_prompt_strength": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.1,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.01,
|
||||
"tooltip": "Blend between the prompt and the image prompt.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls, aspect_ratio: str):
|
||||
try:
|
||||
validate_aspect_ratio(
|
||||
aspect_ratio,
|
||||
minimum_ratio=cls.MINIMUM_RATIO,
|
||||
maximum_ratio=cls.MAXIMUM_RATIO,
|
||||
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
|
||||
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
|
||||
)
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
return True
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
prompt_upsampling=False,
|
||||
raw=False,
|
||||
seed=0,
|
||||
image_prompt=None,
|
||||
image_prompt_strength=0.1,
|
||||
**kwargs,
|
||||
):
|
||||
if image_prompt is None:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.1-ultra/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxProUltraGenerateRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxProUltraGenerateRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
seed=seed,
|
||||
aspect_ratio=validate_aspect_ratio(
|
||||
aspect_ratio,
|
||||
minimum_ratio=self.MINIMUM_RATIO,
|
||||
maximum_ratio=self.MAXIMUM_RATIO,
|
||||
minimum_ratio_str=self.MINIMUM_RATIO_STR,
|
||||
maximum_ratio_str=self.MAXIMUM_RATIO_STR,
|
||||
),
|
||||
raw=raw,
|
||||
image_prompt=(
|
||||
image_prompt
|
||||
if image_prompt is None
|
||||
else convert_image_to_base64(image_prompt)
|
||||
),
|
||||
image_prompt_strength=(
|
||||
None if image_prompt is None else round(image_prompt_strength, 2)
|
||||
),
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
|
||||
class FluxProImageNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates images synchronously based on prompt and resolution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"prompt_upsampling": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
},
|
||||
),
|
||||
"width": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 1024,
|
||||
"min": 256,
|
||||
"max": 1440,
|
||||
"step": 32,
|
||||
},
|
||||
),
|
||||
"height": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 768,
|
||||
"min": 256,
|
||||
"max": 1440,
|
||||
"step": 32,
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image_prompt": (IO.IMAGE,),
|
||||
# "image_prompt_strength": (
|
||||
# IO.FLOAT,
|
||||
# {
|
||||
# "default": 0.1,
|
||||
# "min": 0.0,
|
||||
# "max": 1.0,
|
||||
# "step": 0.01,
|
||||
# "tooltip": "Blend between the prompt and the image prompt.",
|
||||
# },
|
||||
# ),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_upsampling,
|
||||
width: int,
|
||||
height: int,
|
||||
seed=0,
|
||||
image_prompt=None,
|
||||
# image_prompt_strength=0.1,
|
||||
**kwargs,
|
||||
):
|
||||
image_prompt = (
|
||||
image_prompt
|
||||
if image_prompt is None
|
||||
else convert_image_to_base64(image_prompt)
|
||||
)
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.1/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxProGenerateRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxProGenerateRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
width=width,
|
||||
height=height,
|
||||
seed=seed,
|
||||
image_prompt=image_prompt,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
class FluxProExpandNode(ComfyNodeABC):
|
||||
"""
|
||||
Outpaints image based on prompt.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"prompt_upsampling": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
},
|
||||
),
|
||||
"top": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2048,
|
||||
"tooltip": "Number of pixels to expand at the top of the image"
|
||||
},
|
||||
),
|
||||
"bottom": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2048,
|
||||
"tooltip": "Number of pixels to expand at the bottom of the image"
|
||||
},
|
||||
),
|
||||
"left": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2048,
|
||||
"tooltip": "Number of pixels to expand at the left side of the image"
|
||||
},
|
||||
),
|
||||
"right": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2048,
|
||||
"tooltip": "Number of pixels to expand at the right side of the image"
|
||||
},
|
||||
),
|
||||
"guidance": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 60,
|
||||
"min": 1.5,
|
||||
"max": 100,
|
||||
"tooltip": "Guidance strength for the image generation process"
|
||||
},
|
||||
),
|
||||
"steps": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 50,
|
||||
"min": 15,
|
||||
"max": 50,
|
||||
"tooltip": "Number of steps for the image generation process"
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
prompt_upsampling: bool,
|
||||
top: int,
|
||||
bottom: int,
|
||||
left: int,
|
||||
right: int,
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed=0,
|
||||
**kwargs,
|
||||
):
|
||||
image = convert_image_to_base64(image)
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.0-expand/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxExpandImageRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxExpandImageRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
top=top,
|
||||
bottom=bottom,
|
||||
left=left,
|
||||
right=right,
|
||||
steps=steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
image=image,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
|
||||
class FluxProFillNode(ComfyNodeABC):
|
||||
"""
|
||||
Inpaints image based on mask and prompt.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
"mask": (IO.MASK,),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"prompt_upsampling": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
},
|
||||
),
|
||||
"guidance": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 60,
|
||||
"min": 1.5,
|
||||
"max": 100,
|
||||
"tooltip": "Guidance strength for the image generation process"
|
||||
},
|
||||
),
|
||||
"steps": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 50,
|
||||
"min": 15,
|
||||
"max": 50,
|
||||
"tooltip": "Number of steps for the image generation process"
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
prompt: str,
|
||||
prompt_upsampling: bool,
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed=0,
|
||||
**kwargs,
|
||||
):
|
||||
# prepare mask
|
||||
mask = resize_mask_to_image(mask, image)
|
||||
mask = convert_image_to_base64(convert_mask_to_image(mask))
|
||||
# make sure image will have alpha channel removed
|
||||
image = convert_image_to_base64(image[:,:,:,:3])
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.0-fill/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxFillImageRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxFillImageRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
steps=steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
image=image,
|
||||
mask=mask,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
class FluxProCannyNode(ComfyNodeABC):
|
||||
"""
|
||||
Generate image using a control image (canny).
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"control_image": (IO.IMAGE,),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"prompt_upsampling": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
},
|
||||
),
|
||||
"canny_low_threshold": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.1,
|
||||
"min": 0.01,
|
||||
"max": 0.99,
|
||||
"step": 0.01,
|
||||
"tooltip": "Low threshold for Canny edge detection; ignored if skip_processing is True"
|
||||
},
|
||||
),
|
||||
"canny_high_threshold": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.4,
|
||||
"min": 0.01,
|
||||
"max": 0.99,
|
||||
"step": 0.01,
|
||||
"tooltip": "High threshold for Canny edge detection; ignored if skip_processing is True"
|
||||
},
|
||||
),
|
||||
"skip_preprocessing": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to skip preprocessing; set to True if control_image already is canny-fied, False if it is a raw image.",
|
||||
},
|
||||
),
|
||||
"guidance": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 30,
|
||||
"min": 1,
|
||||
"max": 100,
|
||||
"tooltip": "Guidance strength for the image generation process"
|
||||
},
|
||||
),
|
||||
"steps": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 50,
|
||||
"min": 15,
|
||||
"max": 50,
|
||||
"tooltip": "Number of steps for the image generation process"
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
control_image: torch.Tensor,
|
||||
prompt: str,
|
||||
prompt_upsampling: bool,
|
||||
canny_low_threshold: float,
|
||||
canny_high_threshold: float,
|
||||
skip_preprocessing: bool,
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed=0,
|
||||
**kwargs,
|
||||
):
|
||||
control_image = convert_image_to_base64(control_image[:,:,:,:3])
|
||||
preprocessed_image = None
|
||||
|
||||
# scale canny threshold between 0-500, to match BFL's API
|
||||
def scale_value(value: float, min_val=0, max_val=500):
|
||||
return min_val + value * (max_val - min_val)
|
||||
canny_low_threshold = int(round(scale_value(canny_low_threshold)))
|
||||
canny_high_threshold = int(round(scale_value(canny_high_threshold)))
|
||||
|
||||
|
||||
if skip_preprocessing:
|
||||
preprocessed_image = control_image
|
||||
control_image = None
|
||||
canny_low_threshold = None
|
||||
canny_high_threshold = None
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.0-canny/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxCannyImageRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxCannyImageRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
steps=steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
control_image=control_image,
|
||||
canny_low_threshold=canny_low_threshold,
|
||||
canny_high_threshold=canny_high_threshold,
|
||||
preprocessed_image=preprocessed_image,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
class FluxProDepthNode(ComfyNodeABC):
|
||||
"""
|
||||
Generate image using a control image (depth).
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"control_image": (IO.IMAGE,),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"prompt_upsampling": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
},
|
||||
),
|
||||
"skip_preprocessing": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to skip preprocessing; set to True if control_image already is depth-ified, False if it is a raw image.",
|
||||
},
|
||||
),
|
||||
"guidance": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 15,
|
||||
"min": 1,
|
||||
"max": 100,
|
||||
"tooltip": "Guidance strength for the image generation process"
|
||||
},
|
||||
),
|
||||
"steps": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 50,
|
||||
"min": 15,
|
||||
"max": 50,
|
||||
"tooltip": "Number of steps for the image generation process"
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/BFL"
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
control_image: torch.Tensor,
|
||||
prompt: str,
|
||||
prompt_upsampling: bool,
|
||||
skip_preprocessing: bool,
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed=0,
|
||||
**kwargs,
|
||||
):
|
||||
control_image = convert_image_to_base64(control_image[:,:,:,:3])
|
||||
preprocessed_image = None
|
||||
|
||||
if skip_preprocessing:
|
||||
preprocessed_image = control_image
|
||||
control_image = None
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.0-depth/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxDepthImageRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxDepthImageRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
steps=steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
control_image=control_image,
|
||||
preprocessed_image=preprocessed_image,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
output_image = handle_bfl_synchronous_operation(operation)
|
||||
return (output_image,)
|
||||
|
||||
|
||||
# A dictionary that contains all nodes you want to export with their names
|
||||
# NOTE: names should be globally unique
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"FluxProUltraImageNode": FluxProUltraImageNode,
|
||||
# "FluxProImageNode": FluxProImageNode,
|
||||
"FluxProExpandNode": FluxProExpandNode,
|
||||
"FluxProFillNode": FluxProFillNode,
|
||||
"FluxProCannyNode": FluxProCannyNode,
|
||||
"FluxProDepthNode": FluxProDepthNode,
|
||||
}
|
||||
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"FluxProUltraImageNode": "Flux 1.1 [pro] Ultra Image",
|
||||
# "FluxProImageNode": "Flux 1.1 [pro] Image",
|
||||
"FluxProExpandNode": "Flux.1 Expand Image",
|
||||
"FluxProFillNode": "Flux.1 Fill Image",
|
||||
"FluxProCannyNode": "Flux.1 Canny Control Image",
|
||||
"FluxProDepthNode": "Flux.1 Depth Control Image",
|
||||
}
|
779
comfy_api_nodes/nodes_ideogram.py
Normal file
779
comfy_api_nodes/nodes_ideogram.py
Normal file
@ -0,0 +1,779 @@
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
||||
from inspect import cleandoc
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import io
|
||||
import torch
|
||||
from comfy_api_nodes.apis import (
|
||||
IdeogramGenerateRequest,
|
||||
IdeogramGenerateResponse,
|
||||
ImageRequest,
|
||||
IdeogramV3Request,
|
||||
IdeogramV3EditRequest,
|
||||
)
|
||||
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
)
|
||||
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_bytesio,
|
||||
bytesio_to_image_tensor,
|
||||
resize_mask_to_image,
|
||||
)
|
||||
|
||||
V1_V1_RES_MAP = {
|
||||
"Auto":"AUTO",
|
||||
"512 x 1536":"RESOLUTION_512_1536",
|
||||
"576 x 1408":"RESOLUTION_576_1408",
|
||||
"576 x 1472":"RESOLUTION_576_1472",
|
||||
"576 x 1536":"RESOLUTION_576_1536",
|
||||
"640 x 1024":"RESOLUTION_640_1024",
|
||||
"640 x 1344":"RESOLUTION_640_1344",
|
||||
"640 x 1408":"RESOLUTION_640_1408",
|
||||
"640 x 1472":"RESOLUTION_640_1472",
|
||||
"640 x 1536":"RESOLUTION_640_1536",
|
||||
"704 x 1152":"RESOLUTION_704_1152",
|
||||
"704 x 1216":"RESOLUTION_704_1216",
|
||||
"704 x 1280":"RESOLUTION_704_1280",
|
||||
"704 x 1344":"RESOLUTION_704_1344",
|
||||
"704 x 1408":"RESOLUTION_704_1408",
|
||||
"704 x 1472":"RESOLUTION_704_1472",
|
||||
"720 x 1280":"RESOLUTION_720_1280",
|
||||
"736 x 1312":"RESOLUTION_736_1312",
|
||||
"768 x 1024":"RESOLUTION_768_1024",
|
||||
"768 x 1088":"RESOLUTION_768_1088",
|
||||
"768 x 1152":"RESOLUTION_768_1152",
|
||||
"768 x 1216":"RESOLUTION_768_1216",
|
||||
"768 x 1232":"RESOLUTION_768_1232",
|
||||
"768 x 1280":"RESOLUTION_768_1280",
|
||||
"768 x 1344":"RESOLUTION_768_1344",
|
||||
"832 x 960":"RESOLUTION_832_960",
|
||||
"832 x 1024":"RESOLUTION_832_1024",
|
||||
"832 x 1088":"RESOLUTION_832_1088",
|
||||
"832 x 1152":"RESOLUTION_832_1152",
|
||||
"832 x 1216":"RESOLUTION_832_1216",
|
||||
"832 x 1248":"RESOLUTION_832_1248",
|
||||
"864 x 1152":"RESOLUTION_864_1152",
|
||||
"896 x 960":"RESOLUTION_896_960",
|
||||
"896 x 1024":"RESOLUTION_896_1024",
|
||||
"896 x 1088":"RESOLUTION_896_1088",
|
||||
"896 x 1120":"RESOLUTION_896_1120",
|
||||
"896 x 1152":"RESOLUTION_896_1152",
|
||||
"960 x 832":"RESOLUTION_960_832",
|
||||
"960 x 896":"RESOLUTION_960_896",
|
||||
"960 x 1024":"RESOLUTION_960_1024",
|
||||
"960 x 1088":"RESOLUTION_960_1088",
|
||||
"1024 x 640":"RESOLUTION_1024_640",
|
||||
"1024 x 768":"RESOLUTION_1024_768",
|
||||
"1024 x 832":"RESOLUTION_1024_832",
|
||||
"1024 x 896":"RESOLUTION_1024_896",
|
||||
"1024 x 960":"RESOLUTION_1024_960",
|
||||
"1024 x 1024":"RESOLUTION_1024_1024",
|
||||
"1088 x 768":"RESOLUTION_1088_768",
|
||||
"1088 x 832":"RESOLUTION_1088_832",
|
||||
"1088 x 896":"RESOLUTION_1088_896",
|
||||
"1088 x 960":"RESOLUTION_1088_960",
|
||||
"1120 x 896":"RESOLUTION_1120_896",
|
||||
"1152 x 704":"RESOLUTION_1152_704",
|
||||
"1152 x 768":"RESOLUTION_1152_768",
|
||||
"1152 x 832":"RESOLUTION_1152_832",
|
||||
"1152 x 864":"RESOLUTION_1152_864",
|
||||
"1152 x 896":"RESOLUTION_1152_896",
|
||||
"1216 x 704":"RESOLUTION_1216_704",
|
||||
"1216 x 768":"RESOLUTION_1216_768",
|
||||
"1216 x 832":"RESOLUTION_1216_832",
|
||||
"1232 x 768":"RESOLUTION_1232_768",
|
||||
"1248 x 832":"RESOLUTION_1248_832",
|
||||
"1280 x 704":"RESOLUTION_1280_704",
|
||||
"1280 x 720":"RESOLUTION_1280_720",
|
||||
"1280 x 768":"RESOLUTION_1280_768",
|
||||
"1280 x 800":"RESOLUTION_1280_800",
|
||||
"1312 x 736":"RESOLUTION_1312_736",
|
||||
"1344 x 640":"RESOLUTION_1344_640",
|
||||
"1344 x 704":"RESOLUTION_1344_704",
|
||||
"1344 x 768":"RESOLUTION_1344_768",
|
||||
"1408 x 576":"RESOLUTION_1408_576",
|
||||
"1408 x 640":"RESOLUTION_1408_640",
|
||||
"1408 x 704":"RESOLUTION_1408_704",
|
||||
"1472 x 576":"RESOLUTION_1472_576",
|
||||
"1472 x 640":"RESOLUTION_1472_640",
|
||||
"1472 x 704":"RESOLUTION_1472_704",
|
||||
"1536 x 512":"RESOLUTION_1536_512",
|
||||
"1536 x 576":"RESOLUTION_1536_576",
|
||||
"1536 x 640":"RESOLUTION_1536_640",
|
||||
}
|
||||
|
||||
V1_V2_RATIO_MAP = {
|
||||
"1:1":"ASPECT_1_1",
|
||||
"4:3":"ASPECT_4_3",
|
||||
"3:4":"ASPECT_3_4",
|
||||
"16:9":"ASPECT_16_9",
|
||||
"9:16":"ASPECT_9_16",
|
||||
"2:1":"ASPECT_2_1",
|
||||
"1:2":"ASPECT_1_2",
|
||||
"3:2":"ASPECT_3_2",
|
||||
"2:3":"ASPECT_2_3",
|
||||
"4:5":"ASPECT_4_5",
|
||||
"5:4":"ASPECT_5_4",
|
||||
}
|
||||
|
||||
V3_RATIO_MAP = {
|
||||
"1:3":"1x3",
|
||||
"3:1":"3x1",
|
||||
"1:2":"1x2",
|
||||
"2:1":"2x1",
|
||||
"9:16":"9x16",
|
||||
"16:9":"16x9",
|
||||
"10:16":"10x16",
|
||||
"16:10":"16x10",
|
||||
"2:3":"2x3",
|
||||
"3:2":"3x2",
|
||||
"3:4":"3x4",
|
||||
"4:3":"4x3",
|
||||
"4:5":"4x5",
|
||||
"5:4":"5x4",
|
||||
"1:1":"1x1",
|
||||
}
|
||||
|
||||
V3_RESOLUTIONS= [
|
||||
"Auto",
|
||||
"512x1536",
|
||||
"576x1408",
|
||||
"576x1472",
|
||||
"576x1536",
|
||||
"640x1344",
|
||||
"640x1408",
|
||||
"640x1472",
|
||||
"640x1536",
|
||||
"704x1152",
|
||||
"704x1216",
|
||||
"704x1280",
|
||||
"704x1344",
|
||||
"704x1408",
|
||||
"704x1472",
|
||||
"736x1312",
|
||||
"768x1088",
|
||||
"768x1216",
|
||||
"768x1280",
|
||||
"768x1344",
|
||||
"800x1280",
|
||||
"832x960",
|
||||
"832x1024",
|
||||
"832x1088",
|
||||
"832x1152",
|
||||
"832x1216",
|
||||
"832x1248",
|
||||
"864x1152",
|
||||
"896x960",
|
||||
"896x1024",
|
||||
"896x1088",
|
||||
"896x1120",
|
||||
"896x1152",
|
||||
"960x832",
|
||||
"960x896",
|
||||
"960x1024",
|
||||
"960x1088",
|
||||
"1024x832",
|
||||
"1024x896",
|
||||
"1024x960",
|
||||
"1024x1024",
|
||||
"1088x768",
|
||||
"1088x832",
|
||||
"1088x896",
|
||||
"1088x960",
|
||||
"1120x896",
|
||||
"1152x704",
|
||||
"1152x832",
|
||||
"1152x864",
|
||||
"1152x896",
|
||||
"1216x704",
|
||||
"1216x768",
|
||||
"1216x832",
|
||||
"1248x832",
|
||||
"1280x704",
|
||||
"1280x768",
|
||||
"1280x800",
|
||||
"1312x736",
|
||||
"1344x640",
|
||||
"1344x704",
|
||||
"1344x768",
|
||||
"1408x576",
|
||||
"1408x640",
|
||||
"1408x704",
|
||||
"1472x576",
|
||||
"1472x640",
|
||||
"1472x704",
|
||||
"1536x512",
|
||||
"1536x576",
|
||||
"1536x640"
|
||||
]
|
||||
|
||||
def download_and_process_images(image_urls):
|
||||
"""Helper function to download and process multiple images from URLs"""
|
||||
|
||||
# Initialize list to store image tensors
|
||||
image_tensors = []
|
||||
|
||||
for image_url in image_urls:
|
||||
# Using functions from apinode_utils.py to handle downloading and processing
|
||||
image_bytesio = download_url_to_bytesio(image_url) # Download image content to BytesIO
|
||||
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
|
||||
image_tensors.append(img_tensor)
|
||||
|
||||
# Stack tensors to match (N, width, height, channels)
|
||||
if image_tensors:
|
||||
stacked_tensors = torch.cat(image_tensors, dim=0)
|
||||
else:
|
||||
raise Exception("No valid images were processed")
|
||||
|
||||
return stacked_tensors
|
||||
|
||||
|
||||
class IdeogramV1(ComfyNodeABC):
|
||||
"""
|
||||
Generates images using the Ideogram V1 model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"turbo": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
}
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"aspect_ratio": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": list(V1_V2_RATIO_MAP.keys()),
|
||||
"default": "1:1",
|
||||
"tooltip": "The aspect ratio for image generation.",
|
||||
},
|
||||
),
|
||||
"magic_prompt_option": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["AUTO", "ON", "OFF"],
|
||||
"default": "AUTO",
|
||||
"tooltip": "Determine if MagicPrompt should be used in generation",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2147483647,
|
||||
"step": 1,
|
||||
"control_after_generate": True,
|
||||
"display": "number",
|
||||
},
|
||||
),
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Description of what to exclude from the image",
|
||||
},
|
||||
),
|
||||
"num_images": (
|
||||
IO.INT,
|
||||
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Ideogram/v1"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt,
|
||||
turbo=False,
|
||||
aspect_ratio="1:1",
|
||||
magic_prompt_option="AUTO",
|
||||
seed=0,
|
||||
negative_prompt="",
|
||||
num_images=1,
|
||||
**kwargs,
|
||||
):
|
||||
# Determine the model based on turbo setting
|
||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||
model = "V_1_TURBO" if turbo else "V_1"
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/ideogram/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=IdeogramGenerateRequest,
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=IdeogramGenerateRequest(
|
||||
image_request=ImageRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
num_images=num_images,
|
||||
seed=seed,
|
||||
aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
|
||||
magic_prompt_option=(
|
||||
magic_prompt_option if magic_prompt_option != "AUTO" else None
|
||||
),
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
)
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
|
||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
return (download_and_process_images(image_urls),)
|
||||
|
||||
|
||||
class IdeogramV2(ComfyNodeABC):
|
||||
"""
|
||||
Generates images using the Ideogram V2 model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"turbo": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||
}
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"aspect_ratio": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": list(V1_V2_RATIO_MAP.keys()),
|
||||
"default": "1:1",
|
||||
"tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
|
||||
},
|
||||
),
|
||||
"resolution": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": list(V1_V1_RES_MAP.keys()),
|
||||
"default": "Auto",
|
||||
"tooltip": "The resolution for image generation. If not set to AUTO, this overrides the aspect_ratio setting.",
|
||||
},
|
||||
),
|
||||
"magic_prompt_option": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["AUTO", "ON", "OFF"],
|
||||
"default": "AUTO",
|
||||
"tooltip": "Determine if MagicPrompt should be used in generation",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2147483647,
|
||||
"step": 1,
|
||||
"control_after_generate": True,
|
||||
"display": "number",
|
||||
},
|
||||
),
|
||||
"style_type": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
|
||||
"default": "NONE",
|
||||
"tooltip": "Style type for generation (V2 only)",
|
||||
},
|
||||
),
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Description of what to exclude from the image",
|
||||
},
|
||||
),
|
||||
"num_images": (
|
||||
IO.INT,
|
||||
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
||||
),
|
||||
#"color_palette": (
|
||||
# IO.STRING,
|
||||
# {
|
||||
# "multiline": False,
|
||||
# "default": "",
|
||||
# "tooltip": "Color palette preset name or hex colors with weights",
|
||||
# },
|
||||
#),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Ideogram/v2"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt,
|
||||
turbo=False,
|
||||
aspect_ratio="1:1",
|
||||
resolution="Auto",
|
||||
magic_prompt_option="AUTO",
|
||||
seed=0,
|
||||
style_type="NONE",
|
||||
negative_prompt="",
|
||||
num_images=1,
|
||||
color_palette="",
|
||||
**kwargs,
|
||||
):
|
||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||
resolution = V1_V1_RES_MAP.get(resolution, None)
|
||||
# Determine the model based on turbo setting
|
||||
model = "V_2_TURBO" if turbo else "V_2"
|
||||
|
||||
# Handle resolution vs aspect_ratio logic
|
||||
# If resolution is not AUTO, it overrides aspect_ratio
|
||||
final_resolution = None
|
||||
final_aspect_ratio = None
|
||||
|
||||
if resolution != "AUTO":
|
||||
final_resolution = resolution
|
||||
else:
|
||||
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/ideogram/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=IdeogramGenerateRequest,
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=IdeogramGenerateRequest(
|
||||
image_request=ImageRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
num_images=num_images,
|
||||
seed=seed,
|
||||
aspect_ratio=final_aspect_ratio,
|
||||
resolution=final_resolution,
|
||||
magic_prompt_option=(
|
||||
magic_prompt_option if magic_prompt_option != "AUTO" else None
|
||||
),
|
||||
style_type=style_type if style_type != "NONE" else None,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
color_palette=color_palette if color_palette else None,
|
||||
)
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
|
||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
return (download_and_process_images(image_urls),)
|
||||
|
||||
class IdeogramV3(ComfyNodeABC):
|
||||
"""
|
||||
Generates images using the Ideogram V3 model. Supports both regular image generation from text prompts and image editing with mask.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation or editing",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional reference image for image editing.",
|
||||
},
|
||||
),
|
||||
"mask": (
|
||||
IO.MASK,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
|
||||
},
|
||||
),
|
||||
"aspect_ratio": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": list(V3_RATIO_MAP.keys()),
|
||||
"default": "1:1",
|
||||
"tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to Auto.",
|
||||
},
|
||||
),
|
||||
"resolution": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": V3_RESOLUTIONS,
|
||||
"default": "Auto",
|
||||
"tooltip": "The resolution for image generation. If not set to Auto, this overrides the aspect_ratio setting.",
|
||||
},
|
||||
),
|
||||
"magic_prompt_option": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["AUTO", "ON", "OFF"],
|
||||
"default": "AUTO",
|
||||
"tooltip": "Determine if MagicPrompt should be used in generation",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2147483647,
|
||||
"step": 1,
|
||||
"control_after_generate": True,
|
||||
"display": "number",
|
||||
},
|
||||
),
|
||||
"num_images": (
|
||||
IO.INT,
|
||||
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
||||
),
|
||||
"rendering_speed": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["BALANCED", "TURBO", "QUALITY"],
|
||||
"default": "BALANCED",
|
||||
"tooltip": "Controls the trade-off between generation speed and quality",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Ideogram/v3"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt,
|
||||
image=None,
|
||||
mask=None,
|
||||
resolution="Auto",
|
||||
aspect_ratio="1:1",
|
||||
magic_prompt_option="AUTO",
|
||||
seed=0,
|
||||
num_images=1,
|
||||
rendering_speed="BALANCED",
|
||||
**kwargs,
|
||||
):
|
||||
# Check if both image and mask are provided for editing mode
|
||||
if image is not None and mask is not None:
|
||||
# Edit mode
|
||||
path = "/proxy/ideogram/ideogram-v3/edit"
|
||||
|
||||
# Process image and mask
|
||||
input_tensor = image.squeeze().cpu()
|
||||
# Resize mask to match image dimension
|
||||
mask = resize_mask_to_image(mask, image, allow_gradient=False)
|
||||
# Invert mask, as Ideogram API will edit black areas instead of white areas (opposite of convention).
|
||||
mask = 1.0 - mask
|
||||
|
||||
# Validate mask dimensions match image
|
||||
if mask.shape[1:] != image.shape[1:-1]:
|
||||
raise Exception("Mask and Image must be the same size")
|
||||
|
||||
# Process image
|
||||
img_np = (input_tensor.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(img_np)
|
||||
img_byte_arr = io.BytesIO()
|
||||
img.save(img_byte_arr, format="PNG")
|
||||
img_byte_arr.seek(0)
|
||||
img_binary = img_byte_arr
|
||||
img_binary.name = "image.png"
|
||||
|
||||
# Process mask - white areas will be replaced
|
||||
mask_np = (mask.squeeze().cpu().numpy() * 255).astype(np.uint8)
|
||||
mask_img = Image.fromarray(mask_np)
|
||||
mask_byte_arr = io.BytesIO()
|
||||
mask_img.save(mask_byte_arr, format="PNG")
|
||||
mask_byte_arr.seek(0)
|
||||
mask_binary = mask_byte_arr
|
||||
mask_binary.name = "mask.png"
|
||||
|
||||
# Create edit request
|
||||
edit_request = IdeogramV3EditRequest(
|
||||
prompt=prompt,
|
||||
rendering_speed=rendering_speed,
|
||||
)
|
||||
|
||||
# Add optional parameters
|
||||
if magic_prompt_option != "AUTO":
|
||||
edit_request.magic_prompt = magic_prompt_option
|
||||
if seed != 0:
|
||||
edit_request.seed = seed
|
||||
if num_images > 1:
|
||||
edit_request.num_images = num_images
|
||||
|
||||
# Execute the operation for edit mode
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=path,
|
||||
method=HttpMethod.POST,
|
||||
request_model=IdeogramV3EditRequest,
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=edit_request,
|
||||
files={
|
||||
"image": img_binary,
|
||||
"mask": mask_binary,
|
||||
},
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
elif image is not None or mask is not None:
|
||||
# If only one of image or mask is provided, raise an error
|
||||
raise Exception("Ideogram V3 image editing requires both an image AND a mask")
|
||||
else:
|
||||
# Generation mode
|
||||
path = "/proxy/ideogram/ideogram-v3/generate"
|
||||
|
||||
# Create generation request
|
||||
gen_request = IdeogramV3Request(
|
||||
prompt=prompt,
|
||||
rendering_speed=rendering_speed,
|
||||
)
|
||||
|
||||
# Handle resolution vs aspect ratio
|
||||
if resolution != "Auto":
|
||||
gen_request.resolution = resolution
|
||||
elif aspect_ratio != "1:1":
|
||||
v3_aspect = V3_RATIO_MAP.get(aspect_ratio)
|
||||
if v3_aspect:
|
||||
gen_request.aspect_ratio = v3_aspect
|
||||
|
||||
# Add optional parameters
|
||||
if magic_prompt_option != "AUTO":
|
||||
gen_request.magic_prompt = magic_prompt_option
|
||||
if seed != 0:
|
||||
gen_request.seed = seed
|
||||
if num_images > 1:
|
||||
gen_request.num_images = num_images
|
||||
|
||||
# Execute the operation for generation mode
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=path,
|
||||
method=HttpMethod.POST,
|
||||
request_model=IdeogramV3Request,
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=gen_request,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
# Execute the operation and process response
|
||||
response = operation.execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
|
||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
return (download_and_process_images(image_urls),)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"IdeogramV1": IdeogramV1,
|
||||
"IdeogramV2": IdeogramV2,
|
||||
"IdeogramV3": IdeogramV3,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"IdeogramV1": "Ideogram V1",
|
||||
"IdeogramV2": "Ideogram V2",
|
||||
"IdeogramV3": "Ideogram V3",
|
||||
}
|
||||
|
1629
comfy_api_nodes/nodes_kling.py
Normal file
1629
comfy_api_nodes/nodes_kling.py
Normal file
File diff suppressed because it is too large
Load Diff
704
comfy_api_nodes/nodes_luma.py
Normal file
704
comfy_api_nodes/nodes_luma.py
Normal file
@ -0,0 +1,704 @@
|
||||
from __future__ import annotations
|
||||
from inspect import cleandoc
|
||||
from typing import Optional
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
from comfy_api_nodes.apis.luma_api import (
|
||||
LumaImageModel,
|
||||
LumaVideoModel,
|
||||
LumaVideoOutputResolution,
|
||||
LumaVideoModelOutputDuration,
|
||||
LumaAspectRatio,
|
||||
LumaState,
|
||||
LumaImageGenerationRequest,
|
||||
LumaGenerationRequest,
|
||||
LumaGeneration,
|
||||
LumaCharacterRef,
|
||||
LumaModifyImageRef,
|
||||
LumaImageIdentity,
|
||||
LumaReference,
|
||||
LumaReferenceChain,
|
||||
LumaImageReference,
|
||||
LumaKeyframes,
|
||||
LumaConceptChain,
|
||||
LumaIO,
|
||||
get_luma_concepts,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
upload_images_to_comfyapi,
|
||||
process_image_response,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
class LumaReferenceNode(ComfyNodeABC):
|
||||
"""
|
||||
Holds an image and weight for use with Luma Generate Image node.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (LumaIO.LUMA_REF,)
|
||||
RETURN_NAMES = ("luma_ref",)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "create_luma_reference"
|
||||
CATEGORY = "api node/image/Luma"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"tooltip": "Image to use as reference.",
|
||||
},
|
||||
),
|
||||
"weight": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 1.0,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.01,
|
||||
"tooltip": "Weight of image reference.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {"luma_ref": (LumaIO.LUMA_REF,)},
|
||||
}
|
||||
|
||||
def create_luma_reference(
|
||||
self, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
|
||||
):
|
||||
if luma_ref is not None:
|
||||
luma_ref = luma_ref.clone()
|
||||
else:
|
||||
luma_ref = LumaReferenceChain()
|
||||
luma_ref.add(LumaReference(image=image, weight=round(weight, 2)))
|
||||
return (luma_ref,)
|
||||
|
||||
|
||||
class LumaConceptsNode(ComfyNodeABC):
|
||||
"""
|
||||
Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (LumaIO.LUMA_CONCEPTS,)
|
||||
RETURN_NAMES = ("luma_concepts",)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "create_concepts"
|
||||
CATEGORY = "api node/video/Luma"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"concept1": (get_luma_concepts(include_none=True),),
|
||||
"concept2": (get_luma_concepts(include_none=True),),
|
||||
"concept3": (get_luma_concepts(include_none=True),),
|
||||
"concept4": (get_luma_concepts(include_none=True),),
|
||||
},
|
||||
"optional": {
|
||||
"luma_concepts": (
|
||||
LumaIO.LUMA_CONCEPTS,
|
||||
{
|
||||
"tooltip": "Optional Camera Concepts to add to the ones chosen here."
|
||||
},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
def create_concepts(
|
||||
self,
|
||||
concept1: str,
|
||||
concept2: str,
|
||||
concept3: str,
|
||||
concept4: str,
|
||||
luma_concepts: LumaConceptChain = None,
|
||||
):
|
||||
chain = LumaConceptChain(str_list=[concept1, concept2, concept3, concept4])
|
||||
if luma_concepts is not None:
|
||||
chain = luma_concepts.clone_and_merge(chain)
|
||||
return (chain,)
|
||||
|
||||
|
||||
class LumaImageGenerationNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates images synchronously based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Luma"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"model": ([model.value for model in LumaImageModel],),
|
||||
"aspect_ratio": (
|
||||
[ratio.value for ratio in LumaAspectRatio],
|
||||
{
|
||||
"default": LumaAspectRatio.ratio_16_9,
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
},
|
||||
),
|
||||
"style_image_weight": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 1.0,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.01,
|
||||
"tooltip": "Weight of style image. Ignored if no style_image provided.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image_luma_ref": (
|
||||
LumaIO.LUMA_REF,
|
||||
{
|
||||
"tooltip": "Luma Reference node connection to influence generation with input images; up to 4 images can be considered."
|
||||
},
|
||||
),
|
||||
"style_image": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "Style reference image; only 1 image will be used."},
|
||||
),
|
||||
"character_image": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"tooltip": "Character reference images; can be a batch of multiple, up to 4 images can be considered."
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
aspect_ratio: str,
|
||||
seed,
|
||||
style_image_weight: float,
|
||||
image_luma_ref: LumaReferenceChain = None,
|
||||
style_image: torch.Tensor = None,
|
||||
character_image: torch.Tensor = None,
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=True, min_length=3)
|
||||
# handle image_luma_ref
|
||||
api_image_ref = None
|
||||
if image_luma_ref is not None:
|
||||
api_image_ref = self._convert_luma_refs(
|
||||
image_luma_ref, max_refs=4, auth_kwargs=kwargs,
|
||||
)
|
||||
# handle style_luma_ref
|
||||
api_style_ref = None
|
||||
if style_image is not None:
|
||||
api_style_ref = self._convert_style_image(
|
||||
style_image, weight=style_image_weight, auth_kwargs=kwargs,
|
||||
)
|
||||
# handle character_ref images
|
||||
character_ref = None
|
||||
if character_image is not None:
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
character_image, max_images=4, auth_kwargs=kwargs,
|
||||
)
|
||||
character_ref = LumaCharacterRef(
|
||||
identity0=LumaImageIdentity(images=download_urls)
|
||||
)
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/luma/generations/image",
|
||||
method=HttpMethod.POST,
|
||||
request_model=LumaImageGenerationRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
request=LumaImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
aspect_ratio=aspect_ratio,
|
||||
image_ref=api_image_ref,
|
||||
style_ref=api_style_ref,
|
||||
character_ref=character_ref,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/luma/generations/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
completed_statuses=[LumaState.completed],
|
||||
failed_statuses=[LumaState.failed],
|
||||
status_extractor=lambda x: x.state,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
|
||||
img_response = requests.get(response_poll.assets.image)
|
||||
img = process_image_response(img_response)
|
||||
return (img,)
|
||||
|
||||
def _convert_luma_refs(
|
||||
self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
|
||||
):
|
||||
luma_urls = []
|
||||
ref_count = 0
|
||||
for ref in luma_ref.refs:
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
ref.image, max_images=1, auth_kwargs=auth_kwargs
|
||||
)
|
||||
luma_urls.append(download_urls[0])
|
||||
ref_count += 1
|
||||
if ref_count >= max_refs:
|
||||
break
|
||||
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
|
||||
|
||||
def _convert_style_image(
|
||||
self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
|
||||
):
|
||||
chain = LumaReferenceChain(
|
||||
first_ref=LumaReference(image=style_image, weight=weight)
|
||||
)
|
||||
return self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
|
||||
|
||||
|
||||
class LumaImageModifyNode(ComfyNodeABC):
|
||||
"""
|
||||
Modifies images synchronously based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Luma"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the image generation",
|
||||
},
|
||||
),
|
||||
"image_weight": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.1,
|
||||
"min": 0.0,
|
||||
"max": 0.98,
|
||||
"step": 0.01,
|
||||
"tooltip": "Weight of the image; the closer to 1.0, the less the image will be modified.",
|
||||
},
|
||||
),
|
||||
"model": ([model.value for model in LumaImageModel],),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
image: torch.Tensor,
|
||||
image_weight: float,
|
||||
seed,
|
||||
**kwargs,
|
||||
):
|
||||
# first, upload image
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
image, max_images=1, auth_kwargs=kwargs,
|
||||
)
|
||||
image_url = download_urls[0]
|
||||
# next, make Luma call with download url provided
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/luma/generations/image",
|
||||
method=HttpMethod.POST,
|
||||
request_model=LumaImageGenerationRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
request=LumaImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
modify_image_ref=LumaModifyImageRef(
|
||||
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
|
||||
),
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/luma/generations/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
completed_statuses=[LumaState.completed],
|
||||
failed_statuses=[LumaState.failed],
|
||||
status_extractor=lambda x: x.state,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
|
||||
img_response = requests.get(response_poll.assets.image)
|
||||
img = process_image_response(img_response)
|
||||
return (img,)
|
||||
|
||||
|
||||
class LumaTextToVideoGenerationNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates videos synchronously based on prompt and output_size.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/video/Luma"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the video generation",
|
||||
},
|
||||
),
|
||||
"model": ([model.value for model in LumaVideoModel],),
|
||||
"aspect_ratio": (
|
||||
[ratio.value for ratio in LumaAspectRatio],
|
||||
{
|
||||
"default": LumaAspectRatio.ratio_16_9,
|
||||
},
|
||||
),
|
||||
"resolution": (
|
||||
[resolution.value for resolution in LumaVideoOutputResolution],
|
||||
{
|
||||
"default": LumaVideoOutputResolution.res_540p,
|
||||
},
|
||||
),
|
||||
"duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
|
||||
"loop": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"luma_concepts": (
|
||||
LumaIO.LUMA_CONCEPTS,
|
||||
{
|
||||
"tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
aspect_ratio: str,
|
||||
resolution: str,
|
||||
duration: str,
|
||||
loop: bool,
|
||||
seed,
|
||||
luma_concepts: LumaConceptChain = None,
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False, min_length=3)
|
||||
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
||||
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/luma/generations",
|
||||
method=HttpMethod.POST,
|
||||
request_model=LumaGenerationRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
request=LumaGenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
resolution=resolution,
|
||||
aspect_ratio=aspect_ratio,
|
||||
duration=duration,
|
||||
loop=loop,
|
||||
concepts=luma_concepts.create_api_model() if luma_concepts else None,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/luma/generations/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
completed_statuses=[LumaState.completed],
|
||||
failed_statuses=[LumaState.failed],
|
||||
status_extractor=lambda x: x.state,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
|
||||
vid_response = requests.get(response_poll.assets.video)
|
||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||
|
||||
|
||||
class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates videos synchronously based on prompt, input images, and output_size.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/video/Luma"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the video generation",
|
||||
},
|
||||
),
|
||||
"model": ([model.value for model in LumaVideoModel],),
|
||||
# "aspect_ratio": ([ratio.value for ratio in LumaAspectRatio], {
|
||||
# "default": LumaAspectRatio.ratio_16_9,
|
||||
# }),
|
||||
"resolution": (
|
||||
[resolution.value for resolution in LumaVideoOutputResolution],
|
||||
{
|
||||
"default": LumaVideoOutputResolution.res_540p,
|
||||
},
|
||||
),
|
||||
"duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
|
||||
"loop": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": False,
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"first_image": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "First frame of generated video."},
|
||||
),
|
||||
"last_image": (IO.IMAGE, {"tooltip": "Last frame of generated video."}),
|
||||
"luma_concepts": (
|
||||
LumaIO.LUMA_CONCEPTS,
|
||||
{
|
||||
"tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
resolution: str,
|
||||
duration: str,
|
||||
loop: bool,
|
||||
seed,
|
||||
first_image: torch.Tensor = None,
|
||||
last_image: torch.Tensor = None,
|
||||
luma_concepts: LumaConceptChain = None,
|
||||
**kwargs,
|
||||
):
|
||||
if first_image is None and last_image is None:
|
||||
raise Exception(
|
||||
"At least one of first_image and last_image requires an input."
|
||||
)
|
||||
keyframes = self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
|
||||
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
||||
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/luma/generations",
|
||||
method=HttpMethod.POST,
|
||||
request_model=LumaGenerationRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
request=LumaGenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
aspect_ratio=LumaAspectRatio.ratio_16_9, # ignored, but still needed by the API for some reason
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
loop=loop,
|
||||
keyframes=keyframes,
|
||||
concepts=luma_concepts.create_api_model() if luma_concepts else None,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/luma/generations/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
completed_statuses=[LumaState.completed],
|
||||
failed_statuses=[LumaState.failed],
|
||||
status_extractor=lambda x: x.state,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
|
||||
vid_response = requests.get(response_poll.assets.video)
|
||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||
|
||||
def _convert_to_keyframes(
|
||||
self,
|
||||
first_image: torch.Tensor = None,
|
||||
last_image: torch.Tensor = None,
|
||||
auth_kwargs: Optional[dict[str,str]] = None,
|
||||
):
|
||||
if first_image is None and last_image is None:
|
||||
return None
|
||||
frame0 = None
|
||||
frame1 = None
|
||||
if first_image is not None:
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
first_image, max_images=1, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
frame0 = LumaImageReference(type="image", url=download_urls[0])
|
||||
if last_image is not None:
|
||||
download_urls = upload_images_to_comfyapi(
|
||||
last_image, max_images=1, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
frame1 = LumaImageReference(type="image", url=download_urls[0])
|
||||
return LumaKeyframes(frame0=frame0, frame1=frame1)
|
||||
|
||||
|
||||
# A dictionary that contains all nodes you want to export with their names
|
||||
# NOTE: names should be globally unique
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LumaImageNode": LumaImageGenerationNode,
|
||||
"LumaImageModifyNode": LumaImageModifyNode,
|
||||
"LumaVideoNode": LumaTextToVideoGenerationNode,
|
||||
"LumaImageToVideoNode": LumaImageToVideoGenerationNode,
|
||||
"LumaReferenceNode": LumaReferenceNode,
|
||||
"LumaConceptsNode": LumaConceptsNode,
|
||||
}
|
||||
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"LumaImageNode": "Luma Text to Image",
|
||||
"LumaImageModifyNode": "Luma Image to Image",
|
||||
"LumaVideoNode": "Luma Text to Video",
|
||||
"LumaImageToVideoNode": "Luma Image to Video",
|
||||
"LumaReferenceNode": "Luma Reference",
|
||||
"LumaConceptsNode": "Luma Concepts",
|
||||
}
|
309
comfy_api_nodes/nodes_minimax.py
Normal file
309
comfy_api_nodes/nodes_minimax.py
Normal file
@ -0,0 +1,309 @@
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
from comfy_api_nodes.apis import (
|
||||
MinimaxVideoGenerationRequest,
|
||||
MinimaxVideoGenerationResponse,
|
||||
MinimaxFileRetrieveResponse,
|
||||
MinimaxTaskResultResponse,
|
||||
SubjectReferenceItem,
|
||||
Model
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_bytesio,
|
||||
upload_images_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
import torch
|
||||
import logging
|
||||
|
||||
|
||||
class MinimaxTextToVideoNode:
|
||||
"""
|
||||
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt_text": (
|
||||
"STRING",
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt to guide the video generation",
|
||||
},
|
||||
),
|
||||
"model": (
|
||||
[
|
||||
"T2V-01",
|
||||
"T2V-01-Director",
|
||||
],
|
||||
{
|
||||
"default": "T2V-01",
|
||||
"tooltip": "Model to use for video generation",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
DESCRIPTION = "Generates videos from prompts using MiniMax's API"
|
||||
FUNCTION = "generate_video"
|
||||
CATEGORY = "api node/video/MiniMax"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def generate_video(
|
||||
self,
|
||||
prompt_text,
|
||||
seed=0,
|
||||
model="T2V-01",
|
||||
image: torch.Tensor=None, # used for ImageToVideo
|
||||
subject: torch.Tensor=None, # used for SubjectToVideo
|
||||
**kwargs,
|
||||
):
|
||||
'''
|
||||
Function used between MiniMax nodes - supports T2V, I2V, and S2V, based on provided arguments.
|
||||
'''
|
||||
if image is None:
|
||||
validate_string(prompt_text, field_name="prompt_text")
|
||||
# upload image, if passed in
|
||||
image_url = None
|
||||
if image is not None:
|
||||
image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)[0]
|
||||
|
||||
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
|
||||
subject_reference = None
|
||||
if subject is not None:
|
||||
subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs)[0]
|
||||
subject_reference = [SubjectReferenceItem(image=subject_url)]
|
||||
|
||||
|
||||
video_generate_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/video_generation",
|
||||
method=HttpMethod.POST,
|
||||
request_model=MinimaxVideoGenerationRequest,
|
||||
response_model=MinimaxVideoGenerationResponse,
|
||||
),
|
||||
request=MinimaxVideoGenerationRequest(
|
||||
model=Model(model),
|
||||
prompt=prompt_text,
|
||||
callback_url=None,
|
||||
first_frame_image=image_url,
|
||||
subject_reference=subject_reference,
|
||||
prompt_optimizer=None,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response = video_generate_operation.execute()
|
||||
|
||||
task_id = response.task_id
|
||||
if not task_id:
|
||||
raise Exception(f"MiniMax generation failed: {response.base_resp}")
|
||||
|
||||
video_generate_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/query/video_generation",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MinimaxTaskResultResponse,
|
||||
query_params={"task_id": task_id},
|
||||
),
|
||||
completed_statuses=["Success"],
|
||||
failed_statuses=["Fail"],
|
||||
status_extractor=lambda x: x.status.value,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
task_result = video_generate_operation.execute()
|
||||
|
||||
file_id = task_result.file_id
|
||||
if file_id is None:
|
||||
raise Exception("Request was not successful. Missing file ID.")
|
||||
file_retrieve_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/files/retrieve",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MinimaxFileRetrieveResponse,
|
||||
query_params={"file_id": int(file_id)},
|
||||
),
|
||||
request=EmptyRequest(),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
file_result = file_retrieve_operation.execute()
|
||||
|
||||
file_url = file_result.file.download_url
|
||||
if file_url is None:
|
||||
raise Exception(
|
||||
f"No video was found in the response. Full response: {file_result.model_dump()}"
|
||||
)
|
||||
logging.info(f"Generated video URL: {file_url}")
|
||||
|
||||
video_io = download_url_to_bytesio(file_url)
|
||||
if video_io is None:
|
||||
error_msg = f"Failed to download video from {file_url}"
|
||||
logging.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
return (VideoFromFile(video_io),)
|
||||
|
||||
|
||||
class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
|
||||
"""
|
||||
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"tooltip": "Image to use as first frame of video generation"
|
||||
},
|
||||
),
|
||||
"prompt_text": (
|
||||
"STRING",
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt to guide the video generation",
|
||||
},
|
||||
),
|
||||
"model": (
|
||||
[
|
||||
"I2V-01-Director",
|
||||
"I2V-01",
|
||||
"I2V-01-live",
|
||||
],
|
||||
{
|
||||
"default": "I2V-01",
|
||||
"tooltip": "Model to use for video generation",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API"
|
||||
FUNCTION = "generate_video"
|
||||
CATEGORY = "api node/video/MiniMax"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
|
||||
class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
|
||||
"""
|
||||
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"subject": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"tooltip": "Image of subject to reference video generation"
|
||||
},
|
||||
),
|
||||
"prompt_text": (
|
||||
"STRING",
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt to guide the video generation",
|
||||
},
|
||||
),
|
||||
"model": (
|
||||
[
|
||||
"S2V-01",
|
||||
],
|
||||
{
|
||||
"default": "S2V-01",
|
||||
"tooltip": "Model to use for video generation",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API"
|
||||
FUNCTION = "generate_video"
|
||||
CATEGORY = "api node/video/MiniMax"
|
||||
API_NODE = True
|
||||
OUTPUT_NODE = True
|
||||
|
||||
|
||||
# A dictionary that contains all nodes you want to export with their names
|
||||
# NOTE: names should be globally unique
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"MinimaxTextToVideoNode": MinimaxTextToVideoNode,
|
||||
"MinimaxImageToVideoNode": MinimaxImageToVideoNode,
|
||||
# "MinimaxSubjectToVideoNode": MinimaxSubjectToVideoNode,
|
||||
}
|
||||
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"MinimaxTextToVideoNode": "MiniMax Text to Video",
|
||||
"MinimaxImageToVideoNode": "MiniMax Image to Video",
|
||||
"MinimaxSubjectToVideoNode": "MiniMax Subject to Video",
|
||||
}
|
496
comfy_api_nodes/nodes_openai.py
Normal file
496
comfy_api_nodes/nodes_openai.py
Normal file
@ -0,0 +1,496 @@
|
||||
import io
|
||||
from inspect import cleandoc
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
||||
|
||||
|
||||
from comfy_api_nodes.apis import (
|
||||
OpenAIImageGenerationRequest,
|
||||
OpenAIImageEditRequest,
|
||||
OpenAIImageGenerationResponse,
|
||||
)
|
||||
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
)
|
||||
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
downscale_image_tensor,
|
||||
validate_and_cast_response,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
class OpenAIDalle2(ComfyNodeABC):
|
||||
"""
|
||||
Generates images synchronously via OpenAI's DALL·E 2 endpoint.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt for DALL·E",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2**31 - 1,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"control_after_generate": True,
|
||||
"tooltip": "not implemented yet in backend",
|
||||
},
|
||||
),
|
||||
"size": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["256x256", "512x512", "1024x1024"],
|
||||
"default": "1024x1024",
|
||||
"tooltip": "Image size",
|
||||
},
|
||||
),
|
||||
"n": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
"max": 8,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"tooltip": "How many images to generate",
|
||||
},
|
||||
),
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional reference image for image editing.",
|
||||
},
|
||||
),
|
||||
"mask": (
|
||||
IO.MASK,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/OpenAI"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt,
|
||||
seed=0,
|
||||
image=None,
|
||||
mask=None,
|
||||
n=1,
|
||||
size="1024x1024",
|
||||
**kwargs
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
model = "dall-e-2"
|
||||
path = "/proxy/openai/images/generations"
|
||||
content_type = "application/json"
|
||||
request_class = OpenAIImageGenerationRequest
|
||||
img_binary = None
|
||||
|
||||
if image is not None and mask is not None:
|
||||
path = "/proxy/openai/images/edits"
|
||||
content_type = "multipart/form-data"
|
||||
request_class = OpenAIImageEditRequest
|
||||
|
||||
input_tensor = image.squeeze().cpu()
|
||||
height, width, channels = input_tensor.shape
|
||||
rgba_tensor = torch.ones(height, width, 4, device="cpu")
|
||||
rgba_tensor[:, :, :channels] = input_tensor
|
||||
|
||||
if mask.shape[1:] != image.shape[1:-1]:
|
||||
raise Exception("Mask and Image must be the same size")
|
||||
rgba_tensor[:, :, 3] = 1 - mask.squeeze().cpu()
|
||||
|
||||
rgba_tensor = downscale_image_tensor(rgba_tensor.unsqueeze(0)).squeeze()
|
||||
|
||||
image_np = (rgba_tensor.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
img_byte_arr = io.BytesIO()
|
||||
img.save(img_byte_arr, format="PNG")
|
||||
img_byte_arr.seek(0)
|
||||
img_binary = img_byte_arr # .getvalue()
|
||||
img_binary.name = "image.png"
|
||||
elif image is not None or mask is not None:
|
||||
raise Exception("Dall-E 2 image editing requires an image AND a mask")
|
||||
|
||||
# Build the operation
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=path,
|
||||
method=HttpMethod.POST,
|
||||
request_model=request_class,
|
||||
response_model=OpenAIImageGenerationResponse,
|
||||
),
|
||||
request=request_class(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
n=n,
|
||||
size=size,
|
||||
seed=seed,
|
||||
),
|
||||
files=(
|
||||
{
|
||||
"image": img_binary,
|
||||
}
|
||||
if img_binary
|
||||
else None
|
||||
),
|
||||
content_type=content_type,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
|
||||
img_tensor = validate_and_cast_response(response)
|
||||
return (img_tensor,)
|
||||
|
||||
|
||||
class OpenAIDalle3(ComfyNodeABC):
|
||||
"""
|
||||
Generates images synchronously via OpenAI's DALL·E 3 endpoint.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt for DALL·E",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2**31 - 1,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"control_after_generate": True,
|
||||
"tooltip": "not implemented yet in backend",
|
||||
},
|
||||
),
|
||||
"quality": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["standard", "hd"],
|
||||
"default": "standard",
|
||||
"tooltip": "Image quality",
|
||||
},
|
||||
),
|
||||
"style": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["natural", "vivid"],
|
||||
"default": "natural",
|
||||
"tooltip": "Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images.",
|
||||
},
|
||||
),
|
||||
"size": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["1024x1024", "1024x1792", "1792x1024"],
|
||||
"default": "1024x1024",
|
||||
"tooltip": "Image size",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/OpenAI"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt,
|
||||
seed=0,
|
||||
style="natural",
|
||||
quality="standard",
|
||||
size="1024x1024",
|
||||
**kwargs
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
model = "dall-e-3"
|
||||
|
||||
# build the operation
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/openai/images/generations",
|
||||
method=HttpMethod.POST,
|
||||
request_model=OpenAIImageGenerationRequest,
|
||||
response_model=OpenAIImageGenerationResponse,
|
||||
),
|
||||
request=OpenAIImageGenerationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
size=size,
|
||||
style=style,
|
||||
seed=seed,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
|
||||
img_tensor = validate_and_cast_response(response)
|
||||
return (img_tensor,)
|
||||
|
||||
|
||||
class OpenAIGPTImage1(ComfyNodeABC):
|
||||
"""
|
||||
Generates images synchronously via OpenAI's GPT Image 1 endpoint.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt for GPT Image 1",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2**31 - 1,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"control_after_generate": True,
|
||||
"tooltip": "not implemented yet in backend",
|
||||
},
|
||||
),
|
||||
"quality": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["low", "medium", "high"],
|
||||
"default": "low",
|
||||
"tooltip": "Image quality, affects cost and generation time.",
|
||||
},
|
||||
),
|
||||
"background": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["opaque", "transparent"],
|
||||
"default": "opaque",
|
||||
"tooltip": "Return image with or without background",
|
||||
},
|
||||
),
|
||||
"size": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["auto", "1024x1024", "1024x1536", "1536x1024"],
|
||||
"default": "auto",
|
||||
"tooltip": "Image size",
|
||||
},
|
||||
),
|
||||
"n": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
"max": 8,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"tooltip": "How many images to generate",
|
||||
},
|
||||
),
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional reference image for image editing.",
|
||||
},
|
||||
),
|
||||
"mask": (
|
||||
IO.MASK,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/OpenAI"
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
API_NODE = True
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt,
|
||||
seed=0,
|
||||
quality="low",
|
||||
background="opaque",
|
||||
image=None,
|
||||
mask=None,
|
||||
n=1,
|
||||
size="1024x1024",
|
||||
**kwargs
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
model = "gpt-image-1"
|
||||
path = "/proxy/openai/images/generations"
|
||||
content_type="application/json"
|
||||
request_class = OpenAIImageGenerationRequest
|
||||
img_binaries = []
|
||||
mask_binary = None
|
||||
files = []
|
||||
|
||||
if image is not None:
|
||||
path = "/proxy/openai/images/edits"
|
||||
request_class = OpenAIImageEditRequest
|
||||
content_type ="multipart/form-data"
|
||||
|
||||
batch_size = image.shape[0]
|
||||
|
||||
for i in range(batch_size):
|
||||
single_image = image[i : i + 1]
|
||||
scaled_image = downscale_image_tensor(single_image).squeeze()
|
||||
|
||||
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
img_byte_arr = io.BytesIO()
|
||||
img.save(img_byte_arr, format="PNG")
|
||||
img_byte_arr.seek(0)
|
||||
img_binary = img_byte_arr
|
||||
img_binary.name = f"image_{i}.png"
|
||||
|
||||
img_binaries.append(img_binary)
|
||||
if batch_size == 1:
|
||||
files.append(("image", img_binary))
|
||||
else:
|
||||
files.append(("image[]", img_binary))
|
||||
|
||||
if mask is not None:
|
||||
if image is None:
|
||||
raise Exception("Cannot use a mask without an input image")
|
||||
if image.shape[0] != 1:
|
||||
raise Exception("Cannot use a mask with multiple image")
|
||||
if mask.shape[1:] != image.shape[1:-1]:
|
||||
raise Exception("Mask and Image must be the same size")
|
||||
batch, height, width = mask.shape
|
||||
rgba_mask = torch.zeros(height, width, 4, device="cpu")
|
||||
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
|
||||
|
||||
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0)).squeeze()
|
||||
|
||||
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
|
||||
mask_img = Image.fromarray(mask_np)
|
||||
mask_img_byte_arr = io.BytesIO()
|
||||
mask_img.save(mask_img_byte_arr, format="PNG")
|
||||
mask_img_byte_arr.seek(0)
|
||||
mask_binary = mask_img_byte_arr
|
||||
mask_binary.name = "mask.png"
|
||||
files.append(("mask", mask_binary))
|
||||
|
||||
# Build the operation
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=path,
|
||||
method=HttpMethod.POST,
|
||||
request_model=request_class,
|
||||
response_model=OpenAIImageGenerationResponse,
|
||||
),
|
||||
request=request_class(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
background=background,
|
||||
n=n,
|
||||
seed=seed,
|
||||
size=size,
|
||||
),
|
||||
files=files if files else None,
|
||||
content_type=content_type,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
response = operation.execute()
|
||||
|
||||
img_tensor = validate_and_cast_response(response)
|
||||
return (img_tensor,)
|
||||
|
||||
|
||||
# A dictionary that contains all nodes you want to export with their names
|
||||
# NOTE: names should be globally unique
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"OpenAIDalle2": OpenAIDalle2,
|
||||
"OpenAIDalle3": OpenAIDalle3,
|
||||
"OpenAIGPTImage1": OpenAIGPTImage1,
|
||||
}
|
||||
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"OpenAIDalle2": "OpenAI DALL·E 2",
|
||||
"OpenAIDalle3": "OpenAI DALL·E 3",
|
||||
"OpenAIGPTImage1": "OpenAI GPT Image 1",
|
||||
}
|
757
comfy_api_nodes/nodes_pika.py
Normal file
757
comfy_api_nodes/nodes_pika.py
Normal file
@ -0,0 +1,757 @@
|
||||
"""
|
||||
Pika x ComfyUI API Nodes
|
||||
|
||||
Pika API docs: https://pika-827374fb.mintlify.app/api-reference
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from typing import Optional, TypeVar
|
||||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
from comfy_api_nodes.apis import (
|
||||
PikaBodyGenerate22T2vGenerate22T2vPost,
|
||||
PikaGenerateResponse,
|
||||
PikaBodyGenerate22I2vGenerate22I2vPost,
|
||||
PikaVideoResponse,
|
||||
PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||
IngredientsMode,
|
||||
PikaDurationEnum,
|
||||
PikaResolutionEnum,
|
||||
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||
PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
||||
Pikaffect,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
tensor_to_bytesio,
|
||||
download_url_to_video_output,
|
||||
)
|
||||
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
||||
from comfy_api.input_impl.video_types import VideoInput, VideoContainer, VideoCodec
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeOptions
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
|
||||
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
|
||||
PATH_PIKAFFECTS = "/proxy/pika/generate/pikaffects"
|
||||
|
||||
PIKA_API_VERSION = "2.2"
|
||||
PATH_TEXT_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/t2v"
|
||||
PATH_IMAGE_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/i2v"
|
||||
PATH_PIKAFRAMES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikaframes"
|
||||
PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
|
||||
|
||||
PATH_VIDEO_GET = "/proxy/pika/videos"
|
||||
|
||||
|
||||
class PikaApiError(Exception):
|
||||
"""Exception for Pika API errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def is_valid_video_response(response: PikaVideoResponse) -> bool:
|
||||
"""Check if the video response is valid."""
|
||||
return hasattr(response, "url") and response.url is not None
|
||||
|
||||
|
||||
def is_valid_initial_response(response: PikaGenerateResponse) -> bool:
|
||||
"""Check if the initial response is valid."""
|
||||
return hasattr(response, "video_id") and response.video_id is not None
|
||||
|
||||
|
||||
class PikaNodeBase(ComfyNodeABC):
|
||||
"""Base class for Pika nodes."""
|
||||
|
||||
@classmethod
|
||||
def get_base_inputs_types(
|
||||
cls, request_model
|
||||
) -> dict[str, tuple[IO, InputTypeOptions]]:
|
||||
"""Get the base required inputs types common to all Pika nodes."""
|
||||
return {
|
||||
"prompt_text": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
request_model,
|
||||
"promptText",
|
||||
multiline=True,
|
||||
),
|
||||
"negative_prompt": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
request_model,
|
||||
"negativePrompt",
|
||||
multiline=True,
|
||||
),
|
||||
"seed": model_field_to_node_input(
|
||||
IO.INT,
|
||||
request_model,
|
||||
"seed",
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
),
|
||||
"resolution": model_field_to_node_input(
|
||||
IO.COMBO,
|
||||
request_model,
|
||||
"resolution",
|
||||
enum_type=PikaResolutionEnum,
|
||||
),
|
||||
"duration": model_field_to_node_input(
|
||||
IO.COMBO,
|
||||
request_model,
|
||||
"duration",
|
||||
enum_type=PikaDurationEnum,
|
||||
),
|
||||
}
|
||||
|
||||
CATEGORY = "api node/video/Pika"
|
||||
API_NODE = True
|
||||
FUNCTION = "api_call"
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
|
||||
def poll_for_task_status(
|
||||
self, task_id: str, auth_kwargs: Optional[dict[str,str]] = None
|
||||
) -> PikaGenerateResponse:
|
||||
polling_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"{PATH_VIDEO_GET}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=PikaVideoResponse,
|
||||
),
|
||||
completed_statuses=[
|
||||
"finished",
|
||||
],
|
||||
failed_statuses=["failed", "cancelled"],
|
||||
status_extractor=lambda response: (
|
||||
response.status.value if response.status else None
|
||||
),
|
||||
progress_extractor=lambda response: (
|
||||
response.progress if hasattr(response, "progress") else None
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
return polling_operation.execute()
|
||||
|
||||
def execute_task(
|
||||
self,
|
||||
initial_operation: SynchronousOperation[R, PikaGenerateResponse],
|
||||
auth_kwargs: Optional[dict[str,str]] = None,
|
||||
) -> tuple[VideoFromFile]:
|
||||
"""Executes the initial operation then polls for the task status until it is completed.
|
||||
|
||||
Args:
|
||||
initial_operation: The initial operation to execute.
|
||||
auth_kwargs: The authentication token(s) to use for the API call.
|
||||
|
||||
Returns:
|
||||
A tuple containing the video file as a VIDEO output.
|
||||
"""
|
||||
initial_response = initial_operation.execute()
|
||||
if not is_valid_initial_response(initial_response):
|
||||
error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}"
|
||||
logging.error(error_msg)
|
||||
raise PikaApiError(error_msg)
|
||||
|
||||
task_id = initial_response.video_id
|
||||
final_response = self.poll_for_task_status(task_id, auth_kwargs)
|
||||
if not is_valid_video_response(final_response):
|
||||
error_msg = (
|
||||
f"Pika task {task_id} succeeded but no video data found in response."
|
||||
)
|
||||
logging.error(error_msg)
|
||||
raise PikaApiError(error_msg)
|
||||
|
||||
video_url = str(final_response.url)
|
||||
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
|
||||
|
||||
return (download_url_to_video_output(video_url),)
|
||||
|
||||
|
||||
class PikaImageToVideoV2_2(PikaNodeBase):
|
||||
"""Pika 2.2 Image to Video Node."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "The image to convert to video"},
|
||||
),
|
||||
**cls.get_base_inputs_types(PikaBodyGenerate22I2vGenerate22I2vPost),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Sends an image and prompt to the Pika API v2.2 to generate a video."
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
**kwargs
|
||||
) -> tuple[VideoFromFile]:
|
||||
# Convert image to BytesIO
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
image_bytes_io.seek(0)
|
||||
|
||||
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
|
||||
|
||||
# Prepare non-file data
|
||||
pika_request_data = PikaBodyGenerate22I2vGenerate22I2vPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_IMAGE_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=PikaBodyGenerate22I2vGenerate22I2vPost,
|
||||
response_model=PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs)
|
||||
|
||||
|
||||
class PikaTextToVideoNodeV2_2(PikaNodeBase):
|
||||
"""Pika Text2Video v2.2 Node."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
**cls.get_base_inputs_types(PikaBodyGenerate22T2vGenerate22T2vPost),
|
||||
"aspect_ratio": model_field_to_node_input(
|
||||
IO.FLOAT,
|
||||
PikaBodyGenerate22T2vGenerate22T2vPost,
|
||||
"aspectRatio",
|
||||
step=0.001,
|
||||
min=0.4,
|
||||
max=2.5,
|
||||
default=1.7777777777777777,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Sends a text prompt to the Pika API v2.2 to generate a video."
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
aspect_ratio: float,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_TEXT_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=PikaBodyGenerate22T2vGenerate22T2vPost,
|
||||
response_model=PikaGenerateResponse,
|
||||
),
|
||||
request=PikaBodyGenerate22T2vGenerate22T2vPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
aspectRatio=aspect_ratio,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
content_type="application/x-www-form-urlencoded",
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs)
|
||||
|
||||
|
||||
class PikaScenesV2_2(PikaNodeBase):
|
||||
"""PikaScenes v2.2 Node."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
image_ingredient_input = (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "Image that will be used as ingredient to create a video."},
|
||||
)
|
||||
return {
|
||||
"required": {
|
||||
**cls.get_base_inputs_types(
|
||||
PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||
),
|
||||
"ingredients_mode": model_field_to_node_input(
|
||||
IO.COMBO,
|
||||
PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||
"ingredientsMode",
|
||||
enum_type=IngredientsMode,
|
||||
default="creative",
|
||||
),
|
||||
"aspect_ratio": model_field_to_node_input(
|
||||
IO.FLOAT,
|
||||
PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||
"aspectRatio",
|
||||
step=0.001,
|
||||
min=0.4,
|
||||
max=2.5,
|
||||
default=1.7777777777777777,
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image_ingredient_1": image_ingredient_input,
|
||||
"image_ingredient_2": image_ingredient_input,
|
||||
"image_ingredient_3": image_ingredient_input,
|
||||
"image_ingredient_4": image_ingredient_input,
|
||||
"image_ingredient_5": image_ingredient_input,
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them."
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
ingredients_mode: str,
|
||||
aspect_ratio: float,
|
||||
image_ingredient_1: Optional[torch.Tensor] = None,
|
||||
image_ingredient_2: Optional[torch.Tensor] = None,
|
||||
image_ingredient_3: Optional[torch.Tensor] = None,
|
||||
image_ingredient_4: Optional[torch.Tensor] = None,
|
||||
image_ingredient_5: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
# Convert all passed images to BytesIO
|
||||
all_image_bytes_io = []
|
||||
for image in [
|
||||
image_ingredient_1,
|
||||
image_ingredient_2,
|
||||
image_ingredient_3,
|
||||
image_ingredient_4,
|
||||
image_ingredient_5,
|
||||
]:
|
||||
if image is not None:
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
image_bytes_io.seek(0)
|
||||
all_image_bytes_io.append(image_bytes_io)
|
||||
|
||||
pika_files = [
|
||||
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
|
||||
for i, image_bytes_io in enumerate(all_image_bytes_io)
|
||||
]
|
||||
|
||||
pika_request_data = PikaBodyGenerate22C2vGenerate22PikascenesPost(
|
||||
ingredientsMode=ingredients_mode,
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
aspectRatio=aspect_ratio,
|
||||
)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKASCENES,
|
||||
method=HttpMethod.POST,
|
||||
request_model=PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||
response_model=PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs)
|
||||
|
||||
|
||||
class PikAdditionsNode(PikaNodeBase):
|
||||
"""Pika Pikadditions Node. Add an image into a video."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"video": (IO.VIDEO, {"tooltip": "The video to add an image to."}),
|
||||
"image": (IO.IMAGE, {"tooltip": "The image to add to the video."}),
|
||||
"prompt_text": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||
"promptText",
|
||||
multiline=True,
|
||||
),
|
||||
"negative_prompt": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||
"negativePrompt",
|
||||
multiline=True,
|
||||
),
|
||||
"seed": model_field_to_node_input(
|
||||
IO.INT,
|
||||
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||
"seed",
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you’d like to add to create a seamlessly integrated result."
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
video: VideoInput,
|
||||
image: torch.Tensor,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
# Convert video to BytesIO
|
||||
video_bytes_io = io.BytesIO()
|
||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||
video_bytes_io.seek(0)
|
||||
|
||||
# Convert image to BytesIO
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
image_bytes_io.seek(0)
|
||||
|
||||
pika_files = [
|
||||
("video", ("video.mp4", video_bytes_io, "video/mp4")),
|
||||
("image", ("image.png", image_bytes_io, "image/png")),
|
||||
]
|
||||
|
||||
# Prepare non-file data
|
||||
pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKADDITIONS,
|
||||
method=HttpMethod.POST,
|
||||
request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||
response_model=PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs)
|
||||
|
||||
|
||||
class PikaSwapsNode(PikaNodeBase):
|
||||
"""Pika Pikaswaps Node."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"video": (IO.VIDEO, {"tooltip": "The video to swap an object in."}),
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"tooltip": "The image used to replace the masked object in the video."
|
||||
},
|
||||
),
|
||||
"mask": (
|
||||
IO.MASK,
|
||||
{"tooltip": "Use the mask to define areas in the video to replace"},
|
||||
),
|
||||
"prompt_text": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||
"promptText",
|
||||
multiline=True,
|
||||
),
|
||||
"negative_prompt": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||
"negativePrompt",
|
||||
multiline=True,
|
||||
),
|
||||
"seed": model_field_to_node_input(
|
||||
IO.INT,
|
||||
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||
"seed",
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates."
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
video: VideoInput,
|
||||
image: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
# Convert video to BytesIO
|
||||
video_bytes_io = io.BytesIO()
|
||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||
video_bytes_io.seek(0)
|
||||
|
||||
# Convert mask to binary mask with three channels
|
||||
mask = torch.round(mask)
|
||||
mask = mask.repeat(1, 3, 1, 1)
|
||||
|
||||
# Convert 3-channel binary mask to BytesIO
|
||||
mask_bytes_io = io.BytesIO()
|
||||
mask_bytes_io.write(mask.numpy().astype(np.uint8))
|
||||
mask_bytes_io.seek(0)
|
||||
|
||||
# Convert image to BytesIO
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
image_bytes_io.seek(0)
|
||||
|
||||
pika_files = [
|
||||
("video", ("video.mp4", video_bytes_io, "video/mp4")),
|
||||
("image", ("image.png", image_bytes_io, "image/png")),
|
||||
("modifyRegionMask", ("mask.png", mask_bytes_io, "image/png")),
|
||||
]
|
||||
|
||||
# Prepare non-file data
|
||||
pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKADDITIONS,
|
||||
method=HttpMethod.POST,
|
||||
request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||
response_model=PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs)
|
||||
|
||||
|
||||
class PikaffectsNode(PikaNodeBase):
|
||||
"""Pika Pikaffects Node."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
{"tooltip": "The reference image to apply the Pikaffect to."},
|
||||
),
|
||||
"pikaffect": model_field_to_node_input(
|
||||
IO.COMBO,
|
||||
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||
"pikaffect",
|
||||
enum_type=Pikaffect,
|
||||
default="Cake-ify",
|
||||
),
|
||||
"prompt_text": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||
"promptText",
|
||||
multiline=True,
|
||||
),
|
||||
"negative_prompt": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||
"negativePrompt",
|
||||
multiline=True,
|
||||
),
|
||||
"seed": model_field_to_node_input(
|
||||
IO.INT,
|
||||
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||
"seed",
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear"
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
pikaffect: str,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKAFFECTS,
|
||||
method=HttpMethod.POST,
|
||||
request_model=PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||
response_model=PikaGenerateResponse,
|
||||
),
|
||||
request=PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
||||
pikaffect=pikaffect,
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
),
|
||||
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs)
|
||||
|
||||
|
||||
class PikaStartEndFrameNode2_2(PikaNodeBase):
|
||||
"""PikaFrames v2.2 Node."""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image_start": (IO.IMAGE, {"tooltip": "The first image to combine."}),
|
||||
"image_end": (IO.IMAGE, {"tooltip": "The last image to combine."}),
|
||||
**cls.get_base_inputs_types(
|
||||
PikaBodyGenerate22KeyframeGenerate22PikaframesPost
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them."
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
image_start: torch.Tensor,
|
||||
image_end: torch.Tensor,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
**kwargs,
|
||||
) -> tuple[VideoFromFile]:
|
||||
|
||||
pika_files = [
|
||||
(
|
||||
"keyFrames",
|
||||
("image_start.png", tensor_to_bytesio(image_start), "image/png"),
|
||||
),
|
||||
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
|
||||
]
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKAFRAMES,
|
||||
method=HttpMethod.POST,
|
||||
request_model=PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
||||
response_model=PikaGenerateResponse,
|
||||
),
|
||||
request=PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
),
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
return self.execute_task(initial_operation, auth_kwargs=kwargs)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PikaImageToVideoNode2_2": PikaImageToVideoV2_2,
|
||||
"PikaTextToVideoNode2_2": PikaTextToVideoNodeV2_2,
|
||||
"PikaScenesV2_2": PikaScenesV2_2,
|
||||
"Pikadditions": PikAdditionsNode,
|
||||
"Pikaswaps": PikaSwapsNode,
|
||||
"Pikaffects": PikaffectsNode,
|
||||
"PikaStartEndFrameNode2_2": PikaStartEndFrameNode2_2,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"PikaImageToVideoNode2_2": "Pika Image to Video",
|
||||
"PikaTextToVideoNode2_2": "Pika Text to Video",
|
||||
"PikaScenesV2_2": "Pika Scenes (Video Image Composition)",
|
||||
"Pikadditions": "Pikadditions (Video Object Insertion)",
|
||||
"Pikaswaps": "Pika Swaps (Video Object Replacement)",
|
||||
"Pikaffects": "Pikaffects (Video Effects)",
|
||||
"PikaStartEndFrameNode2_2": "Pika Start and End Frame to Video",
|
||||
}
|
492
comfy_api_nodes/nodes_pixverse.py
Normal file
492
comfy_api_nodes/nodes_pixverse.py
Normal file
@ -0,0 +1,492 @@
|
||||
from inspect import cleandoc
|
||||
|
||||
from comfy_api_nodes.apis.pixverse_api import (
|
||||
PixverseTextVideoRequest,
|
||||
PixverseImageVideoRequest,
|
||||
PixverseTransitionVideoRequest,
|
||||
PixverseImageUploadResponse,
|
||||
PixverseVideoResponse,
|
||||
PixverseGenerationStatusResponse,
|
||||
PixverseAspectRatio,
|
||||
PixverseQuality,
|
||||
PixverseDuration,
|
||||
PixverseMotionMode,
|
||||
PixverseStatus,
|
||||
PixverseIO,
|
||||
pixverse_templates,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
tensor_to_bytesio,
|
||||
validate_string,
|
||||
)
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
|
||||
import torch
|
||||
import requests
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
|
||||
# first, upload image to Pixverse and get image id to use in actual generation call
|
||||
files = {
|
||||
"image": tensor_to_bytesio(image)
|
||||
}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/pixverse/image/upload",
|
||||
method=HttpMethod.POST,
|
||||
request_model=EmptyRequest,
|
||||
response_model=PixverseImageUploadResponse,
|
||||
),
|
||||
request=EmptyRequest(),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_upload: PixverseImageUploadResponse = operation.execute()
|
||||
|
||||
if response_upload.Resp is None:
|
||||
raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
|
||||
|
||||
return response_upload.Resp.img_id
|
||||
|
||||
|
||||
class PixverseTemplateNode:
|
||||
"""
|
||||
Select template for PixVerse Video generation.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (PixverseIO.TEMPLATE,)
|
||||
RETURN_NAMES = ("pixverse_template",)
|
||||
FUNCTION = "create_template"
|
||||
CATEGORY = "api node/video/PixVerse"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"template": (list(pixverse_templates.keys()), ),
|
||||
}
|
||||
}
|
||||
|
||||
def create_template(self, template: str):
|
||||
template_id = pixverse_templates.get(template, None)
|
||||
if template_id is None:
|
||||
raise Exception(f"Template '{template}' is not recognized.")
|
||||
# just return the integer
|
||||
return (template_id,)
|
||||
|
||||
|
||||
class PixverseTextToVideoNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates videos synchronously based on prompt and output_size.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/video/PixVerse"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the video generation",
|
||||
},
|
||||
),
|
||||
"aspect_ratio": (
|
||||
[ratio.value for ratio in PixverseAspectRatio],
|
||||
),
|
||||
"quality": (
|
||||
[resolution.value for resolution in PixverseQuality],
|
||||
{
|
||||
"default": PixverseQuality.res_540p,
|
||||
},
|
||||
),
|
||||
"duration_seconds": ([dur.value for dur in PixverseDuration],),
|
||||
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2147483647,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed for video generation.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "",
|
||||
"forceInput": True,
|
||||
"tooltip": "An optional text description of undesired elements on an image.",
|
||||
},
|
||||
),
|
||||
"pixverse_template": (
|
||||
PixverseIO.TEMPLATE,
|
||||
{
|
||||
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
|
||||
}
|
||||
)
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
quality: str,
|
||||
duration_seconds: int,
|
||||
motion_mode: str,
|
||||
seed,
|
||||
negative_prompt: str=None,
|
||||
pixverse_template: int=None,
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
# 1080p is limited to 5 seconds duration
|
||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||
if quality == PixverseQuality.res_1080p:
|
||||
motion_mode = PixverseMotionMode.normal
|
||||
duration_seconds = PixverseDuration.dur_5
|
||||
elif duration_seconds != PixverseDuration.dur_5:
|
||||
motion_mode = PixverseMotionMode.normal
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/pixverse/video/text/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=PixverseTextVideoRequest,
|
||||
response_model=PixverseVideoResponse,
|
||||
),
|
||||
request=PixverseTextVideoRequest(
|
||||
prompt=prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
quality=quality,
|
||||
duration=duration_seconds,
|
||||
motion_mode=motion_mode,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
template_id=pixverse_template,
|
||||
seed=seed,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
if response_api.Resp is None:
|
||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
),
|
||||
completed_statuses=[PixverseStatus.successful],
|
||||
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
|
||||
status_extractor=lambda x: x.Resp.status,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
|
||||
vid_response = requests.get(response_poll.Resp.url)
|
||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||
|
||||
|
||||
class PixverseImageToVideoNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates videos synchronously based on prompt and output_size.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/video/PixVerse"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (
|
||||
IO.IMAGE,
|
||||
),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the video generation",
|
||||
},
|
||||
),
|
||||
"quality": (
|
||||
[resolution.value for resolution in PixverseQuality],
|
||||
{
|
||||
"default": PixverseQuality.res_540p,
|
||||
},
|
||||
),
|
||||
"duration_seconds": ([dur.value for dur in PixverseDuration],),
|
||||
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2147483647,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed for video generation.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "",
|
||||
"forceInput": True,
|
||||
"tooltip": "An optional text description of undesired elements on an image.",
|
||||
},
|
||||
),
|
||||
"pixverse_template": (
|
||||
PixverseIO.TEMPLATE,
|
||||
{
|
||||
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
|
||||
}
|
||||
)
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
quality: str,
|
||||
duration_seconds: int,
|
||||
motion_mode: str,
|
||||
seed,
|
||||
negative_prompt: str=None,
|
||||
pixverse_template: int=None,
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
img_id = upload_image_to_pixverse(image, auth_kwargs=kwargs)
|
||||
|
||||
# 1080p is limited to 5 seconds duration
|
||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||
if quality == PixverseQuality.res_1080p:
|
||||
motion_mode = PixverseMotionMode.normal
|
||||
duration_seconds = PixverseDuration.dur_5
|
||||
elif duration_seconds != PixverseDuration.dur_5:
|
||||
motion_mode = PixverseMotionMode.normal
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/pixverse/video/img/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=PixverseImageVideoRequest,
|
||||
response_model=PixverseVideoResponse,
|
||||
),
|
||||
request=PixverseImageVideoRequest(
|
||||
img_id=img_id,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
duration=duration_seconds,
|
||||
motion_mode=motion_mode,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
template_id=pixverse_template,
|
||||
seed=seed,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
if response_api.Resp is None:
|
||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
),
|
||||
completed_statuses=[PixverseStatus.successful],
|
||||
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
|
||||
status_extractor=lambda x: x.Resp.status,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
|
||||
vid_response = requests.get(response_poll.Resp.url)
|
||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||
|
||||
|
||||
class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates videos synchronously based on prompt and output_size.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/video/PixVerse"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"first_frame": (
|
||||
IO.IMAGE,
|
||||
),
|
||||
"last_frame": (
|
||||
IO.IMAGE,
|
||||
),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Prompt for the video generation",
|
||||
},
|
||||
),
|
||||
"quality": (
|
||||
[resolution.value for resolution in PixverseQuality],
|
||||
{
|
||||
"default": PixverseQuality.res_540p,
|
||||
},
|
||||
),
|
||||
"duration_seconds": ([dur.value for dur in PixverseDuration],),
|
||||
"motion_mode": ([mode.value for mode in PixverseMotionMode],),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 2147483647,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed for video generation.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "",
|
||||
"forceInput": True,
|
||||
"tooltip": "An optional text description of undesired elements on an image.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(
|
||||
self,
|
||||
first_frame: torch.Tensor,
|
||||
last_frame: torch.Tensor,
|
||||
prompt: str,
|
||||
quality: str,
|
||||
duration_seconds: int,
|
||||
motion_mode: str,
|
||||
seed,
|
||||
negative_prompt: str=None,
|
||||
**kwargs,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
first_frame_id = upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
|
||||
last_frame_id = upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
|
||||
|
||||
# 1080p is limited to 5 seconds duration
|
||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||
if quality == PixverseQuality.res_1080p:
|
||||
motion_mode = PixverseMotionMode.normal
|
||||
duration_seconds = PixverseDuration.dur_5
|
||||
elif duration_seconds != PixverseDuration.dur_5:
|
||||
motion_mode = PixverseMotionMode.normal
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/pixverse/video/transition/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=PixverseTransitionVideoRequest,
|
||||
response_model=PixverseVideoResponse,
|
||||
),
|
||||
request=PixverseTransitionVideoRequest(
|
||||
first_frame_img=first_frame_id,
|
||||
last_frame_img=last_frame_id,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
duration=duration_seconds,
|
||||
motion_mode=motion_mode,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
seed=seed,
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
if response_api.Resp is None:
|
||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
),
|
||||
completed_statuses=[PixverseStatus.successful],
|
||||
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
|
||||
status_extractor=lambda x: x.Resp.status,
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll = operation.execute()
|
||||
|
||||
vid_response = requests.get(response_poll.Resp.url)
|
||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PixverseTextToVideoNode": PixverseTextToVideoNode,
|
||||
"PixverseImageToVideoNode": PixverseImageToVideoNode,
|
||||
"PixverseTransitionVideoNode": PixverseTransitionVideoNode,
|
||||
"PixverseTemplateNode": PixverseTemplateNode,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"PixverseTextToVideoNode": "PixVerse Text to Video",
|
||||
"PixverseImageToVideoNode": "PixVerse Image to Video",
|
||||
"PixverseTransitionVideoNode": "PixVerse Transition Video",
|
||||
"PixverseTemplateNode": "PixVerse Template",
|
||||
}
|
1117
comfy_api_nodes/nodes_recraft.py
Normal file
1117
comfy_api_nodes/nodes_recraft.py
Normal file
File diff suppressed because it is too large
Load Diff
614
comfy_api_nodes/nodes_stability.py
Normal file
614
comfy_api_nodes/nodes_stability.py
Normal file
@ -0,0 +1,614 @@
|
||||
from inspect import cleandoc
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
from comfy_api_nodes.apis.stability_api import (
|
||||
StabilityUpscaleConservativeRequest,
|
||||
StabilityUpscaleCreativeRequest,
|
||||
StabilityAsyncResponse,
|
||||
StabilityResultsGetResponse,
|
||||
StabilityStable3_5Request,
|
||||
StabilityStableUltraRequest,
|
||||
StabilityStableUltraResponse,
|
||||
StabilityAspectRatio,
|
||||
Stability_SD3_5_Model,
|
||||
Stability_SD3_5_GenerationMode,
|
||||
get_stability_style_presets,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
bytesio_to_image_tensor,
|
||||
tensor_to_bytesio,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
import torch
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class StabilityPollStatus(str, Enum):
|
||||
finished = "finished"
|
||||
in_progress = "in_progress"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
def get_async_dummy_status(x: StabilityResultsGetResponse):
|
||||
if x.name is not None or x.errors is not None:
|
||||
return StabilityPollStatus.failed
|
||||
elif x.finish_reason is not None:
|
||||
return StabilityPollStatus.finished
|
||||
return StabilityPollStatus.in_progress
|
||||
|
||||
|
||||
class StabilityStableImageUltraNode:
|
||||
"""
|
||||
Generates images synchronously based on prompt and resolution.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Stability AI"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
|
||||
"What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
|
||||
"elements, colors, and subjects will lead to better results. " +
|
||||
"To control the weight of a given word use the format `(word:weight)`," +
|
||||
"where `word` is the word you'd like to control the weight of and `weight`" +
|
||||
"is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`" +
|
||||
"would convey a sky that was blue and green, but more green than blue."
|
||||
},
|
||||
),
|
||||
"aspect_ratio": ([x.value for x in StabilityAspectRatio],
|
||||
{
|
||||
"default": StabilityAspectRatio.ratio_1_1,
|
||||
"tooltip": "Aspect ratio of generated image.",
|
||||
},
|
||||
),
|
||||
"style_preset": (get_stability_style_presets(),
|
||||
{
|
||||
"tooltip": "Optional desired style of generated image.",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 4294967294,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image": (IO.IMAGE,),
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "",
|
||||
"forceInput": True,
|
||||
"tooltip": "A blurb of text describing what you do not wish to see in the output image. This is an advanced feature."
|
||||
},
|
||||
),
|
||||
"image_denoise": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.5,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.01,
|
||||
"tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
|
||||
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
|
||||
**kwargs):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
# prepare image binary if image present
|
||||
image_binary = None
|
||||
if image is not None:
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
|
||||
else:
|
||||
image_denoise = None
|
||||
|
||||
if not negative_prompt:
|
||||
negative_prompt = None
|
||||
if style_preset == "None":
|
||||
style_preset = None
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/generate/ultra",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityStableUltraRequest,
|
||||
response_model=StabilityStableUltraResponse,
|
||||
),
|
||||
request=StabilityStableUltraRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
seed=seed,
|
||||
strength=image_denoise,
|
||||
style_preset=style_preset,
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return (returned_image,)
|
||||
|
||||
|
||||
class StabilityStableImageSD_3_5Node:
|
||||
"""
|
||||
Generates images synchronously based on prompt and resolution.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Stability AI"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
|
||||
},
|
||||
),
|
||||
"model": ([x.value for x in Stability_SD3_5_Model],),
|
||||
"aspect_ratio": ([x.value for x in StabilityAspectRatio],
|
||||
{
|
||||
"default": StabilityAspectRatio.ratio_1_1,
|
||||
"tooltip": "Aspect ratio of generated image.",
|
||||
},
|
||||
),
|
||||
"style_preset": (get_stability_style_presets(),
|
||||
{
|
||||
"tooltip": "Optional desired style of generated image.",
|
||||
},
|
||||
),
|
||||
"cfg_scale": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 4.0,
|
||||
"min": 1.0,
|
||||
"max": 10.0,
|
||||
"step": 0.1,
|
||||
"tooltip": "How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 4294967294,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image": (IO.IMAGE,),
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "",
|
||||
"forceInput": True,
|
||||
"tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
|
||||
},
|
||||
),
|
||||
"image_denoise": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.5,
|
||||
"min": 0.0,
|
||||
"max": 1.0,
|
||||
"step": 0.01,
|
||||
"tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
|
||||
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
|
||||
**kwargs):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
# prepare image binary if image present
|
||||
image_binary = None
|
||||
mode = Stability_SD3_5_GenerationMode.text_to_image
|
||||
if image is not None:
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
|
||||
mode = Stability_SD3_5_GenerationMode.image_to_image
|
||||
aspect_ratio = None
|
||||
else:
|
||||
image_denoise = None
|
||||
|
||||
if not negative_prompt:
|
||||
negative_prompt = None
|
||||
if style_preset == "None":
|
||||
style_preset = None
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/generate/sd3",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityStable3_5Request,
|
||||
response_model=StabilityStableUltraResponse,
|
||||
),
|
||||
request=StabilityStable3_5Request(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
seed=seed,
|
||||
strength=image_denoise,
|
||||
style_preset=style_preset,
|
||||
cfg_scale=cfg_scale,
|
||||
model=model,
|
||||
mode=mode,
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return (returned_image,)
|
||||
|
||||
|
||||
class StabilityUpscaleConservativeNode:
|
||||
"""
|
||||
Upscale image with minimal alterations to 4K resolution.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Stability AI"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
|
||||
},
|
||||
),
|
||||
"creativity": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.35,
|
||||
"min": 0.2,
|
||||
"max": 0.5,
|
||||
"step": 0.01,
|
||||
"tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 4294967294,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "",
|
||||
"forceInput": True,
|
||||
"tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
|
||||
**kwargs):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||
|
||||
if not negative_prompt:
|
||||
negative_prompt = None
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/upscale/conservative",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityUpscaleConservativeRequest,
|
||||
response_model=StabilityStableUltraResponse,
|
||||
),
|
||||
request=StabilityUpscaleConservativeRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
creativity=round(creativity,2),
|
||||
seed=seed,
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return (returned_image,)
|
||||
|
||||
|
||||
class StabilityUpscaleCreativeNode:
|
||||
"""
|
||||
Upscale image with minimal alterations to 4K resolution.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Stability AI"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
|
||||
},
|
||||
),
|
||||
"creativity": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.3,
|
||||
"min": 0.1,
|
||||
"max": 0.5,
|
||||
"step": 0.01,
|
||||
"tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
||||
},
|
||||
),
|
||||
"style_preset": (get_stability_style_presets(),
|
||||
{
|
||||
"tooltip": "Optional desired style of generated image.",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 4294967294,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "The random seed used for creating the noise.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"default": "",
|
||||
"forceInput": True,
|
||||
"tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
|
||||
**kwargs):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||
|
||||
if not negative_prompt:
|
||||
negative_prompt = None
|
||||
if style_preset == "None":
|
||||
style_preset = None
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/upscale/creative",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityUpscaleCreativeRequest,
|
||||
response_model=StabilityAsyncResponse,
|
||||
),
|
||||
request=StabilityUpscaleCreativeRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
creativity=round(creativity,2),
|
||||
style_preset=style_preset,
|
||||
seed=seed,
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/stability/v2beta/results/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=StabilityResultsGetResponse,
|
||||
),
|
||||
poll_interval=3,
|
||||
completed_statuses=[StabilityPollStatus.finished],
|
||||
failed_statuses=[StabilityPollStatus.failed],
|
||||
status_extractor=lambda x: get_async_dummy_status(x),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_poll: StabilityResultsGetResponse = operation.execute()
|
||||
|
||||
if response_poll.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_poll.result)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return (returned_image,)
|
||||
|
||||
|
||||
class StabilityUpscaleFastNode:
|
||||
"""
|
||||
Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images.
|
||||
"""
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "api_call"
|
||||
API_NODE = True
|
||||
CATEGORY = "api node/image/Stability AI"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
},
|
||||
"optional": {
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
def api_call(self, image: torch.Tensor,
|
||||
**kwargs):
|
||||
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
|
||||
|
||||
files = {
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/upscale/fast",
|
||||
method=HttpMethod.POST,
|
||||
request_model=EmptyRequest,
|
||||
response_model=StabilityStableUltraResponse,
|
||||
),
|
||||
request=EmptyRequest(),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
response_api = operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
|
||||
|
||||
image_data = base64.b64decode(response_api.image)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
|
||||
return (returned_image,)
|
||||
|
||||
|
||||
# A dictionary that contains all nodes you want to export with their names
|
||||
# NOTE: names should be globally unique
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"StabilityStableImageUltraNode": StabilityStableImageUltraNode,
|
||||
"StabilityStableImageSD_3_5Node": StabilityStableImageSD_3_5Node,
|
||||
"StabilityUpscaleConservativeNode": StabilityUpscaleConservativeNode,
|
||||
"StabilityUpscaleCreativeNode": StabilityUpscaleCreativeNode,
|
||||
"StabilityUpscaleFastNode": StabilityUpscaleFastNode,
|
||||
}
|
||||
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"StabilityStableImageUltraNode": "Stability AI Stable Image Ultra",
|
||||
"StabilityStableImageSD_3_5Node": "Stability AI Stable Diffusion 3.5 Image",
|
||||
"StabilityUpscaleConservativeNode": "Stability AI Upscale Conservative",
|
||||
"StabilityUpscaleCreativeNode": "Stability AI Upscale Creative",
|
||||
"StabilityUpscaleFastNode": "Stability AI Upscale Fast",
|
||||
}
|
284
comfy_api_nodes/nodes_veo2.py
Normal file
284
comfy_api_nodes/nodes_veo2.py
Normal file
@ -0,0 +1,284 @@
|
||||
import io
|
||||
import logging
|
||||
import base64
|
||||
import requests
|
||||
import torch
|
||||
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
from comfy_api_nodes.apis import (
|
||||
Veo2GenVidRequest,
|
||||
Veo2GenVidResponse,
|
||||
Veo2GenVidPollRequest,
|
||||
Veo2GenVidPollResponse
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
)
|
||||
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
downscale_image_tensor,
|
||||
tensor_to_base64_string
|
||||
)
|
||||
|
||||
def convert_image_to_base64(image: torch.Tensor):
|
||||
if image is None:
|
||||
return None
|
||||
|
||||
scaled_image = downscale_image_tensor(image, total_pixels=2048*2048)
|
||||
return tensor_to_base64_string(scaled_image)
|
||||
|
||||
class VeoVideoGenerationNode(ComfyNodeABC):
|
||||
"""
|
||||
Generates videos from text prompts using Google's Veo API.
|
||||
|
||||
This node can create videos from text descriptions and optional image inputs,
|
||||
with control over parameters like aspect ratio, duration, and more.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text description of the video",
|
||||
},
|
||||
),
|
||||
"aspect_ratio": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["16:9", "9:16"],
|
||||
"default": "16:9",
|
||||
"tooltip": "Aspect ratio of the output video",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"negative_prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Negative text prompt to guide what to avoid in the video",
|
||||
},
|
||||
),
|
||||
"duration_seconds": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 5,
|
||||
"min": 5,
|
||||
"max": 8,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"tooltip": "Duration of the output video in seconds",
|
||||
},
|
||||
),
|
||||
"enhance_prompt": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": True,
|
||||
"tooltip": "Whether to enhance the prompt with AI assistance",
|
||||
}
|
||||
),
|
||||
"person_generation": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"options": ["ALLOW", "BLOCK"],
|
||||
"default": "ALLOW",
|
||||
"tooltip": "Whether to allow generating people in the video",
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFF,
|
||||
"step": 1,
|
||||
"display": "number",
|
||||
"control_after_generate": True,
|
||||
"tooltip": "Seed for video generation (0 for random)",
|
||||
},
|
||||
),
|
||||
"image": (IO.IMAGE, {
|
||||
"default": None,
|
||||
"tooltip": "Optional reference image to guide video generation",
|
||||
}),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.VIDEO,)
|
||||
FUNCTION = "generate_video"
|
||||
CATEGORY = "api node/video/Veo"
|
||||
DESCRIPTION = "Generates videos from text prompts using Google's Veo API"
|
||||
API_NODE = True
|
||||
|
||||
def generate_video(
|
||||
self,
|
||||
prompt,
|
||||
aspect_ratio="16:9",
|
||||
negative_prompt="",
|
||||
duration_seconds=5,
|
||||
enhance_prompt=True,
|
||||
person_generation="ALLOW",
|
||||
seed=0,
|
||||
image=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Prepare the instances for the request
|
||||
instances = []
|
||||
|
||||
instance = {
|
||||
"prompt": prompt
|
||||
}
|
||||
|
||||
# Add image if provided
|
||||
if image is not None:
|
||||
image_base64 = convert_image_to_base64(image)
|
||||
if image_base64:
|
||||
instance["image"] = {
|
||||
"bytesBase64Encoded": image_base64,
|
||||
"mimeType": "image/png"
|
||||
}
|
||||
|
||||
instances.append(instance)
|
||||
|
||||
# Create parameters dictionary
|
||||
parameters = {
|
||||
"aspectRatio": aspect_ratio,
|
||||
"personGeneration": person_generation,
|
||||
"durationSeconds": duration_seconds,
|
||||
"enhancePrompt": enhance_prompt,
|
||||
}
|
||||
|
||||
# Add optional parameters if provided
|
||||
if negative_prompt:
|
||||
parameters["negativePrompt"] = negative_prompt
|
||||
if seed > 0:
|
||||
parameters["seed"] = seed
|
||||
|
||||
# Initial request to start video generation
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/veo/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=Veo2GenVidRequest,
|
||||
response_model=Veo2GenVidResponse
|
||||
),
|
||||
request=Veo2GenVidRequest(
|
||||
instances=instances,
|
||||
parameters=parameters
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
)
|
||||
|
||||
initial_response = initial_operation.execute()
|
||||
operation_name = initial_response.name
|
||||
|
||||
logging.info(f"Veo generation started with operation name: {operation_name}")
|
||||
|
||||
# Define status extractor function
|
||||
def status_extractor(response):
|
||||
# Only return "completed" if the operation is done, regardless of success or failure
|
||||
# We'll check for errors after polling completes
|
||||
return "completed" if response.done else "pending"
|
||||
|
||||
# Define progress extractor function
|
||||
def progress_extractor(response):
|
||||
# Could be enhanced if the API provides progress information
|
||||
return None
|
||||
|
||||
# Define the polling operation
|
||||
poll_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path="/proxy/veo/poll",
|
||||
method=HttpMethod.POST,
|
||||
request_model=Veo2GenVidPollRequest,
|
||||
response_model=Veo2GenVidPollResponse
|
||||
),
|
||||
completed_statuses=["completed"],
|
||||
failed_statuses=[], # No failed statuses, we'll handle errors after polling
|
||||
status_extractor=status_extractor,
|
||||
progress_extractor=progress_extractor,
|
||||
request=Veo2GenVidPollRequest(
|
||||
operationName=operation_name
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
poll_interval=5.0
|
||||
)
|
||||
|
||||
# Execute the polling operation
|
||||
poll_response = poll_operation.execute()
|
||||
|
||||
# Now check for errors in the final response
|
||||
# Check for error in poll response
|
||||
if hasattr(poll_response, 'error') and poll_response.error:
|
||||
error_message = f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})"
|
||||
logging.error(error_message)
|
||||
raise Exception(error_message)
|
||||
|
||||
# Check for RAI filtered content
|
||||
if (hasattr(poll_response.response, 'raiMediaFilteredCount') and
|
||||
poll_response.response.raiMediaFilteredCount > 0):
|
||||
|
||||
# Extract reason message if available
|
||||
if (hasattr(poll_response.response, 'raiMediaFilteredReasons') and
|
||||
poll_response.response.raiMediaFilteredReasons):
|
||||
reason = poll_response.response.raiMediaFilteredReasons[0]
|
||||
error_message = f"Content filtered by Google's Responsible AI practices: {reason} ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
|
||||
else:
|
||||
error_message = f"Content filtered by Google's Responsible AI practices ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
|
||||
|
||||
logging.error(error_message)
|
||||
raise Exception(error_message)
|
||||
|
||||
# Extract video data
|
||||
video_data = None
|
||||
if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0:
|
||||
video = poll_response.response.videos[0]
|
||||
|
||||
# Check if video is provided as base64 or URL
|
||||
if hasattr(video, 'bytesBase64Encoded') and video.bytesBase64Encoded:
|
||||
# Decode base64 string to bytes
|
||||
video_data = base64.b64decode(video.bytesBase64Encoded)
|
||||
elif hasattr(video, 'gcsUri') and video.gcsUri:
|
||||
# Download from URL
|
||||
video_url = video.gcsUri
|
||||
video_response = requests.get(video_url)
|
||||
video_data = video_response.content
|
||||
else:
|
||||
raise Exception("Video returned but no data or URL was provided")
|
||||
else:
|
||||
raise Exception("Video generation completed but no video was returned")
|
||||
|
||||
if not video_data:
|
||||
raise Exception("No video data was returned")
|
||||
|
||||
logging.info("Video generation completed successfully")
|
||||
|
||||
# Convert video data to BytesIO object
|
||||
video_io = io.BytesIO(video_data)
|
||||
|
||||
# Return VideoFromFile object
|
||||
return (VideoFromFile(video_io),)
|
||||
|
||||
|
||||
# Register the node
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"VeoVideoGenerationNode": VeoVideoGenerationNode,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"VeoVideoGenerationNode": "Google Veo2 Video Generation",
|
||||
}
|
10
comfy_api_nodes/redocly-dev.yaml
Normal file
10
comfy_api_nodes/redocly-dev.yaml
Normal file
@ -0,0 +1,10 @@
|
||||
# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes.
|
||||
# This is used for development purposes to generate stubs for unreleased API endpoints.
|
||||
apis:
|
||||
filter:
|
||||
root: openapi.yaml
|
||||
decorators:
|
||||
filter-in:
|
||||
property: tags
|
||||
value: ['API Nodes']
|
||||
matchStrategy: all
|
10
comfy_api_nodes/redocly.yaml
Normal file
10
comfy_api_nodes/redocly.yaml
Normal file
@ -0,0 +1,10 @@
|
||||
# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes.
|
||||
|
||||
apis:
|
||||
filter:
|
||||
root: openapi.yaml
|
||||
decorators:
|
||||
filter-in:
|
||||
property: tags
|
||||
value: ['API Nodes', 'Released']
|
||||
matchStrategy: all
|
49
comfy_extras/nodes_ace.py
Normal file
49
comfy_extras/nodes_ace.py
Normal file
@ -0,0 +1,49 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import node_helpers
|
||||
|
||||
class TextEncodeAceStepAudio:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"clip": ("CLIP", ),
|
||||
"tags": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||
"lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||
"lyrics_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def encode(self, clip, tags, lyrics, lyrics_strength):
|
||||
tokens = clip.tokenize(tags, lyrics=lyrics)
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
|
||||
return (conditioning, )
|
||||
|
||||
|
||||
class EmptyAceStepLatentAudio:
|
||||
def __init__(self):
|
||||
self.device = comfy.model_management.intermediate_device()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/audio"
|
||||
|
||||
def generate(self, seconds, batch_size):
|
||||
length = int(seconds * 44100 / 512 / 8)
|
||||
latent = torch.zeros([batch_size, 8, 16, length], device=self.device)
|
||||
return ({"samples": latent, "type": "audio"}, )
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TextEncodeAceStepAudio": TextEncodeAceStepAudio,
|
||||
"EmptyAceStepLatentAudio": EmptyAceStepLatentAudio,
|
||||
}
|
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import av
|
||||
import torchaudio
|
||||
import torch
|
||||
import comfy.model_management
|
||||
@ -7,7 +8,6 @@ import folder_paths
|
||||
import os
|
||||
import io
|
||||
import json
|
||||
import struct
|
||||
import random
|
||||
import hashlib
|
||||
import node_helpers
|
||||
@ -90,60 +90,118 @@ class VAEDecodeAudio:
|
||||
return ({"waveform": audio, "sample_rate": 44100}, )
|
||||
|
||||
|
||||
def create_vorbis_comment_block(comment_dict, last_block):
|
||||
vendor_string = b'ComfyUI'
|
||||
vendor_length = len(vendor_string)
|
||||
def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
|
||||
|
||||
comments = []
|
||||
for key, value in comment_dict.items():
|
||||
comment = f"{key}={value}".encode('utf-8')
|
||||
comments.append(struct.pack('<I', len(comment)) + comment)
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
results: list[FileLocator] = []
|
||||
|
||||
user_comment_list_length = len(comments)
|
||||
user_comments = b''.join(comments)
|
||||
# Prepare metadata dictionary
|
||||
metadata = {}
|
||||
if not args.disable_metadata:
|
||||
if prompt is not None:
|
||||
metadata["prompt"] = json.dumps(prompt)
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
comment_data = struct.pack('<I', vendor_length) + vendor_string + struct.pack('<I', user_comment_list_length) + user_comments
|
||||
if last_block:
|
||||
id = b'\x84'
|
||||
else:
|
||||
id = b'\x04'
|
||||
comment_block = id + struct.pack('>I', len(comment_data))[1:] + comment_data
|
||||
# Opus supported sample rates
|
||||
OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
|
||||
|
||||
return comment_block
|
||||
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
|
||||
output_path = os.path.join(full_output_folder, file)
|
||||
|
||||
def insert_or_replace_vorbis_comment(flac_io, comment_dict):
|
||||
if len(comment_dict) == 0:
|
||||
return flac_io
|
||||
# Use original sample rate initially
|
||||
sample_rate = audio["sample_rate"]
|
||||
|
||||
flac_io.seek(4)
|
||||
# Handle Opus sample rate requirements
|
||||
if format == "opus":
|
||||
if sample_rate > 48000:
|
||||
sample_rate = 48000
|
||||
elif sample_rate not in OPUS_RATES:
|
||||
# Find the next highest supported rate
|
||||
for rate in sorted(OPUS_RATES):
|
||||
if rate > sample_rate:
|
||||
sample_rate = rate
|
||||
break
|
||||
if sample_rate not in OPUS_RATES: # Fallback if still not supported
|
||||
sample_rate = 48000
|
||||
|
||||
blocks = []
|
||||
last_block = False
|
||||
# Resample if necessary
|
||||
if sample_rate != audio["sample_rate"]:
|
||||
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
|
||||
|
||||
while not last_block:
|
||||
header = flac_io.read(4)
|
||||
last_block = (header[0] & 0x80) != 0
|
||||
block_type = header[0] & 0x7F
|
||||
block_length = struct.unpack('>I', b'\x00' + header[1:])[0]
|
||||
block_data = flac_io.read(block_length)
|
||||
# Create in-memory WAV buffer
|
||||
wav_buffer = io.BytesIO()
|
||||
torchaudio.save(wav_buffer, waveform, sample_rate, format="WAV")
|
||||
wav_buffer.seek(0) # Rewind for reading
|
||||
|
||||
if block_type == 4 or block_type == 1:
|
||||
pass
|
||||
else:
|
||||
header = bytes([(header[0] & (~0x80))]) + header[1:]
|
||||
blocks.append(header + block_data)
|
||||
# Use PyAV to convert and add metadata
|
||||
input_container = av.open(wav_buffer)
|
||||
|
||||
blocks.append(create_vorbis_comment_block(comment_dict, last_block=True))
|
||||
# Create output with specified format
|
||||
output_buffer = io.BytesIO()
|
||||
output_container = av.open(output_buffer, mode='w', format=format)
|
||||
|
||||
new_flac_io = io.BytesIO()
|
||||
new_flac_io.write(b'fLaC')
|
||||
for block in blocks:
|
||||
new_flac_io.write(block)
|
||||
# Set metadata on the container
|
||||
for key, value in metadata.items():
|
||||
output_container.metadata[key] = value
|
||||
|
||||
new_flac_io.write(flac_io.read())
|
||||
return new_flac_io
|
||||
# Set up the output stream with appropriate properties
|
||||
input_container.streams.audio[0]
|
||||
if format == "opus":
|
||||
out_stream = output_container.add_stream("libopus", rate=sample_rate)
|
||||
if quality == "64k":
|
||||
out_stream.bit_rate = 64000
|
||||
elif quality == "96k":
|
||||
out_stream.bit_rate = 96000
|
||||
elif quality == "128k":
|
||||
out_stream.bit_rate = 128000
|
||||
elif quality == "192k":
|
||||
out_stream.bit_rate = 192000
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
elif format == "mp3":
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
|
||||
if quality == "V0":
|
||||
#TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
|
||||
out_stream.codec_context.qscale = 1
|
||||
elif quality == "128k":
|
||||
out_stream.bit_rate = 128000
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
else: #format == "flac":
|
||||
out_stream = output_container.add_stream("flac", rate=sample_rate)
|
||||
|
||||
|
||||
# Copy frames from input to output
|
||||
for frame in input_container.decode(audio=0):
|
||||
frame.pts = None # Let PyAV handle timestamps
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
|
||||
# Flush encoder
|
||||
output_container.mux(out_stream.encode(None))
|
||||
|
||||
# Close containers
|
||||
output_container.close()
|
||||
input_container.close()
|
||||
|
||||
# Write the output to file
|
||||
output_buffer.seek(0)
|
||||
with open(output_path, 'wb') as f:
|
||||
f.write(output_buffer.getbuffer())
|
||||
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
counter += 1
|
||||
|
||||
return { "ui": { "audio": results } }
|
||||
|
||||
class SaveAudio:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
@ -153,50 +211,70 @@ class SaveAudio:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio": ("AUDIO", ),
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"})},
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_audio"
|
||||
FUNCTION = "save_flac"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "audio"
|
||||
|
||||
def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
results: list[FileLocator] = []
|
||||
def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
|
||||
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)
|
||||
|
||||
metadata = {}
|
||||
if not args.disable_metadata:
|
||||
if prompt is not None:
|
||||
metadata["prompt"] = json.dumps(prompt)
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
class SaveAudioMP3:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
|
||||
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.flac"
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio": ("AUDIO", ),
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
||||
"quality": (["V0", "128k", "320k"], {"default": "V0"}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
buff = io.BytesIO()
|
||||
torchaudio.save(buff, waveform, audio["sample_rate"], format="FLAC")
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_mp3"
|
||||
|
||||
buff = insert_or_replace_vorbis_comment(buff, metadata)
|
||||
OUTPUT_NODE = True
|
||||
|
||||
with open(os.path.join(full_output_folder, file), 'wb') as f:
|
||||
f.write(buff.getbuffer())
|
||||
CATEGORY = "audio"
|
||||
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
counter += 1
|
||||
def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
|
||||
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
|
||||
|
||||
return { "ui": { "audio": results } }
|
||||
class SaveAudioOpus:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio": ("AUDIO", ),
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
||||
"quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_opus"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "audio"
|
||||
|
||||
def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"):
|
||||
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
|
||||
|
||||
class PreviewAudio(SaveAudio):
|
||||
def __init__(self):
|
||||
@ -248,7 +326,20 @@ NODE_CLASS_MAPPINGS = {
|
||||
"VAEEncodeAudio": VAEEncodeAudio,
|
||||
"VAEDecodeAudio": VAEDecodeAudio,
|
||||
"SaveAudio": SaveAudio,
|
||||
"SaveAudioMP3": SaveAudioMP3,
|
||||
"SaveAudioOpus": SaveAudioOpus,
|
||||
"LoadAudio": LoadAudio,
|
||||
"PreviewAudio": PreviewAudio,
|
||||
"ConditioningStableAudio": ConditioningStableAudio,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"EmptyLatentAudio": "Empty Latent Audio",
|
||||
"VAEEncodeAudio": "VAE Encode Audio",
|
||||
"VAEDecodeAudio": "VAE Decode Audio",
|
||||
"PreviewAudio": "Preview Audio",
|
||||
"LoadAudio": "Load Audio",
|
||||
"SaveAudio": "Save Audio (FLAC)",
|
||||
"SaveAudioMP3": "Save Audio (MP3)",
|
||||
"SaveAudioOpus": "Save Audio (Opus)",
|
||||
}
|
||||
|
@ -20,6 +20,29 @@ class CLIPTextEncodeControlnet:
|
||||
c.append(n)
|
||||
return (c, )
|
||||
|
||||
class T5TokenizerOptions:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"clip": ("CLIP", ),
|
||||
"min_padding": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
|
||||
"min_length": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "set_options"
|
||||
|
||||
def set_options(self, clip, min_padding, min_length):
|
||||
clip = clip.clone()
|
||||
for t5_type in ["t5xxl", "pile_t5xl", "t5base", "mt5xl", "umt5xxl"]:
|
||||
clip.set_tokenizer_option("{}_min_padding".format(t5_type), min_padding)
|
||||
clip.set_tokenizer_option("{}_min_length".format(t5_type), min_length)
|
||||
|
||||
return (clip, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodeControlnet": CLIPTextEncodeControlnet
|
||||
"CLIPTextEncodeControlnet": CLIPTextEncodeControlnet,
|
||||
"T5TokenizerOptions": T5TokenizerOptions,
|
||||
}
|
||||
|
@ -1,3 +1,4 @@
|
||||
import math
|
||||
import comfy.samplers
|
||||
import comfy.sample
|
||||
from comfy.k_diffusion import sampling as k_diffusion_sampling
|
||||
@ -249,6 +250,55 @@ class SetFirstSigma:
|
||||
sigmas[0] = sigma
|
||||
return (sigmas, )
|
||||
|
||||
class ExtendIntermediateSigmas:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"sigmas": ("SIGMAS", ),
|
||||
"steps": ("INT", {"default": 2, "min": 1, "max": 100}),
|
||||
"start_at_sigma": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 20000.0, "step": 0.01, "round": False}),
|
||||
"end_at_sigma": ("FLOAT", {"default": 12.0, "min": 0.0, "max": 20000.0, "step": 0.01, "round": False}),
|
||||
"spacing": (['linear', 'cosine', 'sine'],),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling/sigmas"
|
||||
|
||||
FUNCTION = "extend"
|
||||
|
||||
def extend(self, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str):
|
||||
if start_at_sigma < 0:
|
||||
start_at_sigma = float("inf")
|
||||
|
||||
interpolator = {
|
||||
'linear': lambda x: x,
|
||||
'cosine': lambda x: torch.sin(x*math.pi/2),
|
||||
'sine': lambda x: 1 - torch.cos(x*math.pi/2)
|
||||
}[spacing]
|
||||
|
||||
# linear space for our interpolation function
|
||||
x = torch.linspace(0, 1, steps + 1, device=sigmas.device)[1:-1]
|
||||
computed_spacing = interpolator(x)
|
||||
|
||||
extended_sigmas = []
|
||||
for i in range(len(sigmas) - 1):
|
||||
sigma_current = sigmas[i]
|
||||
sigma_next = sigmas[i+1]
|
||||
|
||||
extended_sigmas.append(sigma_current)
|
||||
|
||||
if end_at_sigma <= sigma_current <= start_at_sigma:
|
||||
interpolated_steps = computed_spacing * (sigma_next - sigma_current) + sigma_current
|
||||
extended_sigmas.extend(interpolated_steps.tolist())
|
||||
|
||||
# Add the last sigma value
|
||||
if len(sigmas) > 0:
|
||||
extended_sigmas.append(sigmas[-1])
|
||||
|
||||
extended_sigmas = torch.FloatTensor(extended_sigmas)
|
||||
|
||||
return (extended_sigmas,)
|
||||
|
||||
class KSamplerSelect:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -735,6 +785,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"SplitSigmasDenoise": SplitSigmasDenoise,
|
||||
"FlipSigmas": FlipSigmas,
|
||||
"SetFirstSigma": SetFirstSigma,
|
||||
"ExtendIntermediateSigmas": ExtendIntermediateSigmas,
|
||||
|
||||
"CFGGuider": CFGGuider,
|
||||
"DualCFGGuider": DualCFGGuider,
|
||||
|
@ -10,6 +10,9 @@ from PIL.PngImagePlugin import PngInfo
|
||||
import numpy as np
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from io import BytesIO
|
||||
from inspect import cleandoc
|
||||
|
||||
from comfy.comfy_types import FileLocator
|
||||
|
||||
@ -190,10 +193,109 @@ class SaveAnimatedPNG:
|
||||
|
||||
return { "ui": { "images": results, "animated": (True,)} }
|
||||
|
||||
class SVG:
|
||||
"""
|
||||
Stores SVG representations via a list of BytesIO objects.
|
||||
"""
|
||||
def __init__(self, data: list[BytesIO]):
|
||||
self.data = data
|
||||
|
||||
def combine(self, other: 'SVG') -> 'SVG':
|
||||
return SVG(self.data + other.data)
|
||||
|
||||
@staticmethod
|
||||
def combine_all(svgs: list['SVG']) -> 'SVG':
|
||||
all_svgs_list: list[BytesIO] = []
|
||||
for svg_item in svgs:
|
||||
all_svgs_list.extend(svg_item.data)
|
||||
return SVG(all_svgs_list)
|
||||
|
||||
class SaveSVGNode:
|
||||
"""
|
||||
Save SVG files on disk.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
|
||||
RETURN_TYPES = ()
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "save_svg"
|
||||
CATEGORY = "image/save" # Changed
|
||||
OUTPUT_NODE = True
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"svg": ("SVG",), # Changed
|
||||
"filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
|
||||
},
|
||||
"hidden": {
|
||||
"prompt": "PROMPT",
|
||||
"extra_pnginfo": "EXTRA_PNGINFO"
|
||||
}
|
||||
}
|
||||
|
||||
def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
results = list()
|
||||
|
||||
# Prepare metadata JSON
|
||||
metadata_dict = {}
|
||||
if prompt is not None:
|
||||
metadata_dict["prompt"] = prompt
|
||||
if extra_pnginfo is not None:
|
||||
metadata_dict.update(extra_pnginfo)
|
||||
|
||||
# Convert metadata to JSON string
|
||||
metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
|
||||
|
||||
for batch_number, svg_bytes in enumerate(svg.data):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.svg"
|
||||
|
||||
# Read SVG content
|
||||
svg_bytes.seek(0)
|
||||
svg_content = svg_bytes.read().decode('utf-8')
|
||||
|
||||
# Inject metadata if available
|
||||
if metadata_json:
|
||||
# Create metadata element with CDATA section
|
||||
metadata_element = f""" <metadata>
|
||||
<![CDATA[
|
||||
{metadata_json}
|
||||
]]>
|
||||
</metadata>
|
||||
"""
|
||||
# Insert metadata after opening svg tag using regex with a replacement function
|
||||
def replacement(match):
|
||||
# match.group(1) contains the captured <svg> tag
|
||||
return match.group(1) + '\n' + metadata_element
|
||||
|
||||
# Apply the substitution
|
||||
svg_content = re.sub(r'(<svg[^>]*>)', replacement, svg_content, flags=re.UNICODE)
|
||||
|
||||
# Write the modified SVG to file
|
||||
with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
|
||||
svg_file.write(svg_content.encode('utf-8'))
|
||||
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
counter += 1
|
||||
return { "ui": { "images": results } }
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageCrop": ImageCrop,
|
||||
"RepeatImageBatch": RepeatImageBatch,
|
||||
"ImageFromBatch": ImageFromBatch,
|
||||
"SaveAnimatedWEBP": SaveAnimatedWEBP,
|
||||
"SaveAnimatedPNG": SaveAnimatedPNG,
|
||||
"SaveSVGNode": SaveSVGNode,
|
||||
}
|
||||
|
@ -2,6 +2,10 @@ import nodes
|
||||
import folder_paths
|
||||
import os
|
||||
|
||||
from comfy.comfy_types import IO
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
|
||||
|
||||
def normalize_path(path):
|
||||
return path.replace('\\', '/')
|
||||
|
||||
@ -21,8 +25,8 @@ class Load3D():
|
||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA")
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info")
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info", "recording_video")
|
||||
|
||||
FUNCTION = "process"
|
||||
EXPERIMENTAL = True
|
||||
@ -41,7 +45,14 @@ class Load3D():
|
||||
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
||||
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
|
||||
|
||||
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info']
|
||||
video = None
|
||||
|
||||
if image['recording'] != "":
|
||||
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
|
||||
|
||||
video = VideoFromFile(recording_video_path)
|
||||
|
||||
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'], video
|
||||
|
||||
class Load3DAnimation():
|
||||
@classmethod
|
||||
@ -59,8 +70,8 @@ class Load3DAnimation():
|
||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA")
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info")
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
|
||||
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video")
|
||||
|
||||
FUNCTION = "process"
|
||||
EXPERIMENTAL = True
|
||||
@ -77,7 +88,14 @@ class Load3DAnimation():
|
||||
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
||||
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
||||
|
||||
return output_image, output_mask, model_file, normal_image, image['camera_info']
|
||||
video = None
|
||||
|
||||
if image['recording'] != "":
|
||||
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
|
||||
|
||||
video = VideoFromFile(recording_video_path)
|
||||
|
||||
return output_image, output_mask, model_file, normal_image, image['camera_info'], video
|
||||
|
||||
class Preview3D():
|
||||
@classmethod
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user