mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-10 11:35:40 +00:00
Merge branch 'master' into asset-management
This commit is contained in:
2
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
2
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -22,7 +22,7 @@ body:
|
||||
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
|
||||
options:
|
||||
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
|
||||
required: true
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Expected Behavior
|
||||
|
2
.github/ISSUE_TEMPLATE/user-support.yml
vendored
2
.github/ISSUE_TEMPLATE/user-support.yml
vendored
@@ -18,7 +18,7 @@ body:
|
||||
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
|
||||
options:
|
||||
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
|
||||
required: true
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Your question
|
||||
|
13
.github/workflows/stable-release.yml
vendored
13
.github/workflows/stable-release.yml
vendored
@@ -12,17 +12,17 @@ on:
|
||||
description: 'CUDA version'
|
||||
required: true
|
||||
type: string
|
||||
default: "128"
|
||||
default: "129"
|
||||
python_minor:
|
||||
description: 'Python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "12"
|
||||
default: "13"
|
||||
python_patch:
|
||||
description: 'Python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "10"
|
||||
default: "6"
|
||||
|
||||
|
||||
jobs:
|
||||
@@ -67,6 +67,11 @@ jobs:
|
||||
./python.exe get-pip.py
|
||||
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
|
||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
||||
rm ./Lib/site-packages/torch/lib/libprotoc.lib
|
||||
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
|
||||
|
||||
cd ..
|
||||
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
@@ -85,7 +90,7 @@ jobs:
|
||||
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
|
||||
|
||||
cd ComfyUI_windows_portable
|
||||
|
@@ -17,19 +17,19 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "128"
|
||||
default: "129"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "12"
|
||||
default: "13"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "10"
|
||||
default: "6"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
12
.github/workflows/windows_release_package.yml
vendored
12
.github/workflows/windows_release_package.yml
vendored
@@ -7,19 +7,19 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "128"
|
||||
default: "129"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "12"
|
||||
default: "13"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "10"
|
||||
default: "6"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@@ -64,6 +64,10 @@ jobs:
|
||||
./python.exe get-pip.py
|
||||
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
|
||||
rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space
|
||||
rm ./Lib/site-packages/torch/lib/libprotoc.lib
|
||||
rm ./Lib/site-packages/torch/lib/libprotobuf.lib
|
||||
cd ..
|
||||
|
||||
git clone --depth 1 https://github.com/comfyanonymous/taesd
|
||||
@@ -82,7 +86,7 @@ jobs:
|
||||
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z
|
||||
|
||||
cd ComfyUI_windows_portable
|
||||
|
27
CODEOWNERS
27
CODEOWNERS
@@ -5,20 +5,21 @@
|
||||
# Inlined the team members for now.
|
||||
|
||||
# Maintainers
|
||||
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
||||
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||
|
||||
# Python web server
|
||||
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
||||
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
|
||||
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
|
||||
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
|
||||
|
||||
# Node developers
|
||||
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
||||
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
|
||||
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
|
||||
/comfy_api_nodes/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
|
||||
|
22
README.md
22
README.md
@@ -211,27 +211,19 @@ This is the command to install the nightly with ROCm 6.4 which might have some p
|
||||
|
||||
### Intel GPUs (Windows and Linux)
|
||||
|
||||
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip (currently available in PyTorch nightly builds). More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
||||
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
||||
|
||||
1. To install PyTorch nightly, use the following command:
|
||||
1. To install PyTorch xpu, use the following command:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu```
|
||||
|
||||
This is the command to install the Pytorch xpu nightly which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
|
||||
|
||||
2. Launch ComfyUI by running `python main.py`
|
||||
|
||||
|
||||
(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
|
||||
|
||||
1. For Intel® Arc™ A-Series Graphics utilizing IPEX, create a conda environment and use the commands below:
|
||||
|
||||
```
|
||||
conda install libuv
|
||||
pip install torch==2.3.1.post0+cxx11.abi torchvision==0.18.1.post0+cxx11.abi torchaudio==2.3.1.post0+cxx11.abi intel-extension-for-pytorch==2.3.110.post0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
|
||||
```
|
||||
|
||||
For other supported Intel GPUs with IPEX, visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
|
||||
|
||||
Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
|
||||
1. visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
|
||||
|
||||
### NVIDIA
|
||||
|
||||
|
@@ -132,6 +132,8 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
|
||||
|
||||
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
|
||||
|
||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||
|
||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||
|
||||
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
||||
|
540
comfy/context_windows.py
Normal file
540
comfy/context_windows.py
Normal file
@@ -0,0 +1,540 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
import torch
|
||||
import numpy as np
|
||||
import collections
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
import logging
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.controlnet import ControlBase
|
||||
|
||||
|
||||
class ContextWindowABC(ABC):
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_tensor(self, full: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Get torch.Tensor applicable to current window.
|
||||
"""
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def add_window(self, full: torch.Tensor, to_add: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply torch.Tensor of window to the full tensor, in place. Returns reference to updated full tensor, not a copy.
|
||||
"""
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
class ContextHandlerABC(ABC):
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: ContextWindowABC, device=None) -> list:
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
|
||||
|
||||
class IndexListContextWindow(ContextWindowABC):
|
||||
def __init__(self, index_list: list[int], dim: int=0):
|
||||
self.index_list = index_list
|
||||
self.context_length = len(index_list)
|
||||
self.dim = dim
|
||||
|
||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
|
||||
if dim is None:
|
||||
dim = self.dim
|
||||
if dim == 0 and full.shape[dim] == 1:
|
||||
return full
|
||||
idx = [slice(None)] * dim + [self.index_list]
|
||||
return full[idx].to(device)
|
||||
|
||||
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
|
||||
if dim is None:
|
||||
dim = self.dim
|
||||
idx = [slice(None)] * dim + [self.index_list]
|
||||
full[idx] += to_add
|
||||
return full
|
||||
|
||||
|
||||
class IndexListCallbacks:
|
||||
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
||||
COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
|
||||
EXECUTE_START = "execute_start"
|
||||
EXECUTE_CLEANUP = "execute_cleanup"
|
||||
|
||||
def init_callbacks(self):
|
||||
return {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextSchedule:
|
||||
name: str
|
||||
func: Callable
|
||||
|
||||
@dataclass
|
||||
class ContextFuseMethod:
|
||||
name: str
|
||||
func: Callable
|
||||
|
||||
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
||||
class IndexListContextHandler(ContextHandlerABC):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
|
||||
self.context_schedule = context_schedule
|
||||
self.fuse_method = fuse_method
|
||||
self.context_length = context_length
|
||||
self.context_overlap = context_overlap
|
||||
self.context_stride = context_stride
|
||||
self.closed_loop = closed_loop
|
||||
self.dim = dim
|
||||
self._step = 0
|
||||
|
||||
self.callbacks = {}
|
||||
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
||||
if x_in.size(self.dim) > self.context_length:
|
||||
logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
|
||||
return True
|
||||
return False
|
||||
|
||||
def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
|
||||
if control.previous_controlnet is not None:
|
||||
self.prepare_control_objects(control.previous_controlnet, device)
|
||||
return control
|
||||
|
||||
def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: IndexListContextWindow, device=None) -> list:
|
||||
if cond_in is None:
|
||||
return None
|
||||
# reuse or resize cond items to match context requirements
|
||||
resized_cond = []
|
||||
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
|
||||
for actual_cond in cond_in:
|
||||
resized_actual_cond = actual_cond.copy()
|
||||
# now we are in the inner dict - "pooled_output" is a tensor, "control" is a ControlBase object, "model_conds" is dictionary
|
||||
for key in actual_cond:
|
||||
try:
|
||||
cond_item = actual_cond[key]
|
||||
if isinstance(cond_item, torch.Tensor):
|
||||
# check that tensor is the expected length - x.size(0)
|
||||
if self.dim < cond_item.ndim and cond_item.size(self.dim) == x_in.size(self.dim):
|
||||
# if so, it's subsetting time - tell controls the expected indeces so they can handle them
|
||||
actual_cond_item = window.get_tensor(cond_item)
|
||||
resized_actual_cond[key] = actual_cond_item.to(device)
|
||||
else:
|
||||
resized_actual_cond[key] = cond_item.to(device)
|
||||
# look for control
|
||||
elif key == "control":
|
||||
resized_actual_cond[key] = self.prepare_control_objects(cond_item, device)
|
||||
elif isinstance(cond_item, dict):
|
||||
new_cond_item = cond_item.copy()
|
||||
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
|
||||
for cond_key, cond_value in new_cond_item.items():
|
||||
if isinstance(cond_value, torch.Tensor):
|
||||
if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
|
||||
# if has cond that is a Tensor, check if needs to be subset
|
||||
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
|
||||
elif cond_key == "num_video_frames": # for SVD
|
||||
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
|
||||
new_cond_item[cond_key].cond = window.context_length
|
||||
resized_actual_cond[key] = new_cond_item
|
||||
else:
|
||||
resized_actual_cond[key] = cond_item
|
||||
finally:
|
||||
del cond_item # just in case to prevent VRAM issues
|
||||
resized_cond.append(resized_actual_cond)
|
||||
return resized_cond
|
||||
|
||||
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
|
||||
matches = torch.nonzero(mask)
|
||||
if torch.numel(matches) == 0:
|
||||
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
||||
self._step = int(matches[0].item())
|
||||
|
||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
||||
context_windows = self.context_schedule.func(full_length, self, model_options)
|
||||
context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
|
||||
return context_windows
|
||||
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
self.set_step(timestep, model_options)
|
||||
context_windows = self.get_context_windows(model, x_in, model_options)
|
||||
enumerated_context_windows = list(enumerate(context_windows))
|
||||
|
||||
conds_final = [torch.zeros_like(x_in) for _ in conds]
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
||||
else:
|
||||
counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
||||
biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
for enum_window in enumerated_context_windows:
|
||||
results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
|
||||
for result in results:
|
||||
self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
|
||||
conds_final, counts_final, biases_final)
|
||||
try:
|
||||
# finalize conds
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
# relative is already normalized, so return as is
|
||||
del counts_final
|
||||
return conds_final
|
||||
else:
|
||||
# normalize conds via division by context usage counts
|
||||
for i in range(len(conds_final)):
|
||||
conds_final[i] /= counts_final[i]
|
||||
del counts_final
|
||||
return conds_final
|
||||
finally:
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
|
||||
model_options, device=None, first_device=None):
|
||||
results: list[ContextResults] = []
|
||||
for window_idx, window in enumerated_context_windows:
|
||||
# allow processing to end between context window executions for faster Cancel
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
|
||||
|
||||
# update exposed params
|
||||
model_options["transformer_options"]["context_window"] = window
|
||||
# get subsections of x, timestep, conds
|
||||
sub_x = window.get_tensor(x_in, device)
|
||||
sub_timestep = window.get_tensor(timestep, device, dim=0)
|
||||
sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]
|
||||
|
||||
sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
|
||||
if device is not None:
|
||||
for i in range(len(sub_conds_out)):
|
||||
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
|
||||
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
|
||||
return results
|
||||
|
||||
|
||||
def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_conds, window: IndexListContextWindow, window_idx: int, total_windows: int, timestep: torch.Tensor,
|
||||
conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]):
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
for pos, idx in enumerate(window.index_list):
|
||||
# bias is the influence of a specific index in relation to the whole context window
|
||||
bias = 1 - abs(idx - (window.index_list[0] + window.index_list[-1]) / 2) / ((window.index_list[-1] - window.index_list[0] + 1e-2) / 2)
|
||||
bias = max(1e-2, bias)
|
||||
# take weighted average relative to total bias of current idx
|
||||
for i in range(len(sub_conds_out)):
|
||||
bias_total = biases_final[i][idx]
|
||||
prev_weight = (bias_total / (bias_total + bias))
|
||||
new_weight = (bias / (bias_total + bias))
|
||||
# account for dims of tensors
|
||||
idx_window = [slice(None)] * self.dim + [idx]
|
||||
pos_window = [slice(None)] * self.dim + [pos]
|
||||
# apply new values
|
||||
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
|
||||
biases_final[i][idx] = bias_total + bias
|
||||
else:
|
||||
# add conds and counts based on weights of fuse method
|
||||
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep)
|
||||
weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
|
||||
for i in range(len(sub_conds_out)):
|
||||
window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
|
||||
window.add_window(counts_final[i], weights_tensor)
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks):
|
||||
callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
|
||||
|
||||
|
||||
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
|
||||
# limit noise_shape length to context_length for more accurate vram use estimation
|
||||
model_options = kwargs.get("model_options", None)
|
||||
if model_options is None:
|
||||
raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
|
||||
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
||||
if handler is not None:
|
||||
noise_shape = list(noise_shape)
|
||||
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
|
||||
return executor(model, noise_shape, *args, **kwargs)
|
||||
|
||||
|
||||
def create_prepare_sampling_wrapper(model: ModelPatcher):
|
||||
model.add_wrapper_with_key(
|
||||
comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING,
|
||||
"ContextWindows_prepare_sampling",
|
||||
_prepare_sampling_wrapper
|
||||
)
|
||||
|
||||
|
||||
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
|
||||
total_dims = len(x_in.shape)
|
||||
weights_tensor = torch.Tensor(weights).to(device=device)
|
||||
for _ in range(dim):
|
||||
weights_tensor = weights_tensor.unsqueeze(0)
|
||||
for _ in range(total_dims - dim - 1):
|
||||
weights_tensor = weights_tensor.unsqueeze(-1)
|
||||
return weights_tensor
|
||||
|
||||
def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]:
|
||||
total_dims = len(x_in.shape)
|
||||
shape = []
|
||||
for _ in range(dim):
|
||||
shape.append(1)
|
||||
shape.append(x_in.shape[dim])
|
||||
for _ in range(total_dims - dim - 1):
|
||||
shape.append(1)
|
||||
return shape
|
||||
|
||||
class ContextSchedules:
|
||||
UNIFORM_LOOPED = "looped_uniform"
|
||||
UNIFORM_STANDARD = "standard_uniform"
|
||||
STATIC_STANDARD = "standard_static"
|
||||
BATCHED = "batched"
|
||||
|
||||
|
||||
# from https://github.com/neggles/animatediff-cli/blob/main/src/animatediff/pipelines/context.py
|
||||
def create_windows_uniform_looped(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
|
||||
windows = []
|
||||
if num_frames < handler.context_length:
|
||||
windows.append(list(range(num_frames)))
|
||||
return windows
|
||||
|
||||
context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
|
||||
# obtain uniform windows as normal, looping and all
|
||||
for context_step in 1 << np.arange(context_stride):
|
||||
pad = int(round(num_frames * ordered_halving(handler._step)))
|
||||
for j in range(
|
||||
int(ordered_halving(handler._step) * context_step) + pad,
|
||||
num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap),
|
||||
(handler.context_length * context_step - handler.context_overlap),
|
||||
):
|
||||
windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
|
||||
|
||||
return windows
|
||||
|
||||
def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
|
||||
# unlike looped, uniform_straight does NOT allow windows that loop back to the beginning;
|
||||
# instead, they get shifted to the corresponding end of the frames.
|
||||
# in the case that a window (shifted or not) is identical to the previous one, it gets skipped.
|
||||
windows = []
|
||||
if num_frames <= handler.context_length:
|
||||
windows.append(list(range(num_frames)))
|
||||
return windows
|
||||
|
||||
context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
|
||||
# first, obtain uniform windows as normal, looping and all
|
||||
for context_step in 1 << np.arange(context_stride):
|
||||
pad = int(round(num_frames * ordered_halving(handler._step)))
|
||||
for j in range(
|
||||
int(ordered_halving(handler._step) * context_step) + pad,
|
||||
num_frames + pad + (-handler.context_overlap),
|
||||
(handler.context_length * context_step - handler.context_overlap),
|
||||
):
|
||||
windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])
|
||||
|
||||
# now that windows are created, shift any windows that loop, and delete duplicate windows
|
||||
delete_idxs = []
|
||||
win_i = 0
|
||||
while win_i < len(windows):
|
||||
# if window is rolls over itself, need to shift it
|
||||
is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames)
|
||||
if is_roll:
|
||||
roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides
|
||||
shift_window_to_end(windows[win_i], num_frames=num_frames)
|
||||
# check if next window (cyclical) is missing roll_val
|
||||
if roll_val not in windows[(win_i+1) % len(windows)]:
|
||||
# need to insert new window here - just insert window starting at roll_val
|
||||
windows.insert(win_i+1, list(range(roll_val, roll_val + handler.context_length)))
|
||||
# delete window if it's not unique
|
||||
for pre_i in range(0, win_i):
|
||||
if windows[win_i] == windows[pre_i]:
|
||||
delete_idxs.append(win_i)
|
||||
break
|
||||
win_i += 1
|
||||
|
||||
# reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation
|
||||
delete_idxs.reverse()
|
||||
for i in delete_idxs:
|
||||
windows.pop(i)
|
||||
|
||||
return windows
|
||||
|
||||
|
||||
def create_windows_static_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
|
||||
windows = []
|
||||
if num_frames <= handler.context_length:
|
||||
windows.append(list(range(num_frames)))
|
||||
return windows
|
||||
# always return the same set of windows
|
||||
delta = handler.context_length - handler.context_overlap
|
||||
for start_idx in range(0, num_frames, delta):
|
||||
# if past the end of frames, move start_idx back to allow same context_length
|
||||
ending = start_idx + handler.context_length
|
||||
if ending >= num_frames:
|
||||
final_delta = ending - num_frames
|
||||
final_start_idx = start_idx - final_delta
|
||||
windows.append(list(range(final_start_idx, final_start_idx + handler.context_length)))
|
||||
break
|
||||
windows.append(list(range(start_idx, start_idx + handler.context_length)))
|
||||
return windows
|
||||
|
||||
|
||||
def create_windows_batched(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
|
||||
windows = []
|
||||
if num_frames <= handler.context_length:
|
||||
windows.append(list(range(num_frames)))
|
||||
return windows
|
||||
# always return the same set of windows;
|
||||
# no overlap, just cut up based on context_length;
|
||||
# last window size will be different if num_frames % opts.context_length != 0
|
||||
for start_idx in range(0, num_frames, handler.context_length):
|
||||
windows.append(list(range(start_idx, min(start_idx + handler.context_length, num_frames))))
|
||||
return windows
|
||||
|
||||
|
||||
def create_windows_default(num_frames: int, handler: IndexListContextHandler):
|
||||
return [list(range(num_frames))]
|
||||
|
||||
|
||||
CONTEXT_MAPPING = {
|
||||
ContextSchedules.UNIFORM_LOOPED: create_windows_uniform_looped,
|
||||
ContextSchedules.UNIFORM_STANDARD: create_windows_uniform_standard,
|
||||
ContextSchedules.STATIC_STANDARD: create_windows_static_standard,
|
||||
ContextSchedules.BATCHED: create_windows_batched,
|
||||
}
|
||||
|
||||
|
||||
def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
|
||||
func = CONTEXT_MAPPING.get(context_schedule, None)
|
||||
if func is None:
|
||||
raise ValueError(f"Unknown context_schedule '{context_schedule}'.")
|
||||
return ContextSchedule(context_schedule, func)
|
||||
|
||||
|
||||
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None):
|
||||
return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs)
|
||||
|
||||
|
||||
def create_weights_flat(length: int, **kwargs) -> list[float]:
|
||||
# weight is the same for all
|
||||
return [1.0] * length
|
||||
|
||||
def create_weights_pyramid(length: int, **kwargs) -> list[float]:
|
||||
# weight is based on the distance away from the edge of the context window;
|
||||
# based on weighted average concept in FreeNoise paper
|
||||
if length % 2 == 0:
|
||||
max_weight = length // 2
|
||||
weight_sequence = list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1))
|
||||
else:
|
||||
max_weight = (length + 1) // 2
|
||||
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
|
||||
return weight_sequence
|
||||
|
||||
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs):
|
||||
# based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
|
||||
# only expected overlap is given different weights
|
||||
weights_torch = torch.ones((length))
|
||||
# blend left-side on all except first window
|
||||
if min(idxs) > 0:
|
||||
ramp_up = torch.linspace(1e-37, 1, handler.context_overlap)
|
||||
weights_torch[:handler.context_overlap] = ramp_up
|
||||
# blend right-side on all except last window
|
||||
if max(idxs) < full_length-1:
|
||||
ramp_down = torch.linspace(1, 1e-37, handler.context_overlap)
|
||||
weights_torch[-handler.context_overlap:] = ramp_down
|
||||
return weights_torch
|
||||
|
||||
class ContextFuseMethods:
|
||||
FLAT = "flat"
|
||||
PYRAMID = "pyramid"
|
||||
RELATIVE = "relative"
|
||||
OVERLAP_LINEAR = "overlap-linear"
|
||||
|
||||
LIST = [PYRAMID, FLAT, OVERLAP_LINEAR]
|
||||
LIST_STATIC = [PYRAMID, RELATIVE, FLAT, OVERLAP_LINEAR]
|
||||
|
||||
|
||||
FUSE_MAPPING = {
|
||||
ContextFuseMethods.FLAT: create_weights_flat,
|
||||
ContextFuseMethods.PYRAMID: create_weights_pyramid,
|
||||
ContextFuseMethods.RELATIVE: create_weights_pyramid,
|
||||
ContextFuseMethods.OVERLAP_LINEAR: create_weights_overlap_linear,
|
||||
}
|
||||
|
||||
def get_matching_fuse_method(fuse_method: str) -> ContextFuseMethod:
|
||||
func = FUSE_MAPPING.get(fuse_method, None)
|
||||
if func is None:
|
||||
raise ValueError(f"Unknown fuse_method '{fuse_method}'.")
|
||||
return ContextFuseMethod(fuse_method, func)
|
||||
|
||||
# Returns fraction that has denominator that is a power of 2
|
||||
def ordered_halving(val):
|
||||
# get binary value, padded with 0s for 64 bits
|
||||
bin_str = f"{val:064b}"
|
||||
# flip binary value, padding included
|
||||
bin_flip = bin_str[::-1]
|
||||
# convert binary to int
|
||||
as_int = int(bin_flip, 2)
|
||||
# divide by 1 << 64, equivalent to 2**64, or 18446744073709551616,
|
||||
# or b10000000000000000000000000000000000000000000000000000000000000000 (1 with 64 zero's)
|
||||
return as_int / (1 << 64)
|
||||
|
||||
|
||||
def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
|
||||
all_indexes = list(range(num_frames))
|
||||
for w in windows:
|
||||
for val in w:
|
||||
try:
|
||||
all_indexes.remove(val)
|
||||
except ValueError:
|
||||
pass
|
||||
return all_indexes
|
||||
|
||||
|
||||
def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
|
||||
prev_val = -1
|
||||
for i, val in enumerate(window):
|
||||
val = val % num_frames
|
||||
if val < prev_val:
|
||||
return True, i
|
||||
prev_val = val
|
||||
return False, -1
|
||||
|
||||
|
||||
def shift_window_to_start(window: list[int], num_frames: int):
|
||||
start_val = window[0]
|
||||
for i in range(len(window)):
|
||||
# 1) subtract each element by start_val to move vals relative to the start of all frames
|
||||
# 2) add num_frames and take modulus to get adjusted vals
|
||||
window[i] = ((window[i] - start_val) + num_frames) % num_frames
|
||||
|
||||
|
||||
def shift_window_to_end(window: list[int], num_frames: int):
|
||||
# 1) shift window to start
|
||||
shift_window_to_start(window, num_frames)
|
||||
end_val = window[-1]
|
||||
end_delta = num_frames - end_val - 1
|
||||
for i in range(len(window)):
|
||||
# 2) add end_delta to each val to slide windows to end
|
||||
window[i] = window[i] + end_delta
|
@@ -224,20 +224,28 @@ class Flux(nn.Module):
|
||||
if ref_latents is not None:
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
index_ref_method = kwargs.get("ref_latents_method", "offset") == "index"
|
||||
for ref in ref_latents:
|
||||
if index_ref_method:
|
||||
index += 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
else:
|
||||
index = 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
if ref.shape[-2] + h > ref.shape[-1] + w:
|
||||
w_offset = w
|
||||
else:
|
||||
h_offset = h
|
||||
|
||||
kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
|
||||
img = torch.cat([img, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
h = max(h, ref.shape[-2] + h_offset)
|
||||
w = max(w, ref.shape[-1] + w_offset)
|
||||
|
||||
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
img = torch.cat([img, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
out = out[:, :img_tokens]
|
||||
|
@@ -178,7 +178,7 @@ class FourierEmbedder(nn.Module):
|
||||
|
||||
class CrossAttentionProcessor:
|
||||
def __call__(self, attn, q, k, v):
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v)
|
||||
return out
|
||||
|
||||
|
||||
|
@@ -448,7 +448,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
if SDP_BATCH_LIMIT >= b:
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
@@ -461,7 +461,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
if mask.shape[0] > 1:
|
||||
m = mask[i : i + SDP_BATCH_LIMIT]
|
||||
|
||||
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
|
||||
out[i : i + SDP_BATCH_LIMIT] = comfy.ops.scaled_dot_product_attention(
|
||||
q[i : i + SDP_BATCH_LIMIT],
|
||||
k[i : i + SDP_BATCH_LIMIT],
|
||||
v[i : i + SDP_BATCH_LIMIT],
|
||||
|
@@ -285,7 +285,7 @@ def pytorch_attention(q, k, v):
|
||||
)
|
||||
|
||||
try:
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
out = out.transpose(2, 3).reshape(orig_shape)
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||
|
@@ -333,21 +333,25 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def pos_embeds(self, x, context):
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0):
|
||||
bs, c, t, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
|
||||
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||
|
||||
txt_start = round(max(h_len, w_len))
|
||||
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -356,19 +360,46 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
context,
|
||||
attention_mask=None,
|
||||
guidance: torch.Tensor = None,
|
||||
ref_latents=None,
|
||||
transformer_options={},
|
||||
**kwargs
|
||||
):
|
||||
timestep = timesteps
|
||||
encoder_hidden_states = context
|
||||
encoder_hidden_states_mask = attention_mask
|
||||
|
||||
image_rotary_emb = self.pos_embeds(x, context)
|
||||
hidden_states, img_ids, orig_shape = self.process_img(x)
|
||||
num_embeds = hidden_states.shape[1]
|
||||
|
||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
|
||||
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||
if ref_latents is not None:
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
|
||||
for ref in ref_latents:
|
||||
if index_ref_method:
|
||||
index += 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
else:
|
||||
index = 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
if ref.shape[-2] + h > ref.shape[-1] + w:
|
||||
w_offset = w
|
||||
else:
|
||||
h_offset = h
|
||||
h = max(h, ref.shape[-2] + h_offset)
|
||||
w = max(w, ref.shape[-1] + w_offset)
|
||||
|
||||
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size)))
|
||||
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
@@ -383,7 +414,19 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
else self.time_text_embed(timestep, guidance, hidden_states)
|
||||
)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -395,6 +438,6 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
||||
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
|
||||
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
|
||||
|
@@ -391,6 +391,7 @@ class WanModel(torch.nn.Module):
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
flf_pos_embed_token_number=None,
|
||||
in_dim_ref_conv=None,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
@@ -484,6 +485,11 @@ class WanModel(torch.nn.Module):
|
||||
else:
|
||||
self.img_emb = None
|
||||
|
||||
if in_dim_ref_conv is not None:
|
||||
self.ref_conv = operations.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:], device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
else:
|
||||
self.ref_conv = None
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
x,
|
||||
@@ -526,6 +532,13 @@ class WanModel(torch.nn.Module):
|
||||
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
|
||||
full_ref = None
|
||||
if self.ref_conv is not None:
|
||||
full_ref = kwargs.get("reference_latent", None)
|
||||
if full_ref is not None:
|
||||
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
|
||||
x = torch.concat((full_ref, x), dim=1)
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
|
||||
@@ -552,6 +565,9 @@ class WanModel(torch.nn.Module):
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
if full_ref is not None:
|
||||
x = x[:, full_ref.shape[1]:]
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
@@ -570,6 +586,9 @@ class WanModel(torch.nn.Module):
|
||||
x = torch.cat([x, time_dim_concat], dim=2)
|
||||
t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
|
||||
|
||||
if self.ref_conv is not None and "reference_latent" in kwargs:
|
||||
t_len += 1
|
||||
|
||||
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
|
||||
@@ -749,7 +768,12 @@ class CameraWanModel(WanModel):
|
||||
operations=None,
|
||||
):
|
||||
|
||||
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||
if model_type == 'camera':
|
||||
model_type = 'i2v'
|
||||
else:
|
||||
model_type = 't2v'
|
||||
|
||||
super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
|
||||
|
@@ -890,6 +890,10 @@ class Flux(BaseModel):
|
||||
for lat in ref_latents:
|
||||
latents.append(self.process_latent_in(lat))
|
||||
out['ref_latents'] = comfy.conds.CONDList(latents)
|
||||
|
||||
ref_latents_method = kwargs.get("reference_latents_method", None)
|
||||
if ref_latents_method is not None:
|
||||
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
@@ -1124,6 +1128,10 @@ class WAN21(BaseModel):
|
||||
mask = mask.repeat(1, 4, 1, 1, 1)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
|
||||
concat_mask_index = kwargs.get("concat_mask_index", 0)
|
||||
if concat_mask_index != 0:
|
||||
return torch.cat((image[:, :concat_mask_index], mask, image[:, concat_mask_index:]), dim=1)
|
||||
else:
|
||||
return torch.cat((mask, image), dim=1)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
@@ -1140,6 +1148,10 @@ class WAN21(BaseModel):
|
||||
if time_dim_concat is not None:
|
||||
out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat))
|
||||
|
||||
reference_latents = kwargs.get("reference_latents", None)
|
||||
if reference_latents is not None:
|
||||
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0])
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@@ -1319,4 +1331,14 @@ class QwenImage(BaseModel):
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
latents = []
|
||||
for lat in ref_latents:
|
||||
latents.append(self.process_latent_in(lat))
|
||||
out['ref_latents'] = comfy.conds.CONDList(latents)
|
||||
|
||||
ref_latents_method = kwargs.get("reference_latents_method", None)
|
||||
if ref_latents_method is not None:
|
||||
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
|
||||
return out
|
||||
|
@@ -364,7 +364,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
|
||||
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
|
||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "camera"
|
||||
else:
|
||||
dit_config["model_type"] = "camera_2.2"
|
||||
else:
|
||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "i2v"
|
||||
@@ -373,6 +376,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
|
||||
if flf_weight is not None:
|
||||
dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
|
||||
|
||||
ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix))
|
||||
if ref_conv_weight is not None:
|
||||
dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
||||
|
@@ -78,7 +78,6 @@ try:
|
||||
torch_version = torch.version.__version__
|
||||
temp = torch_version.split(".")
|
||||
torch_version_numeric = (int(temp[0]), int(temp[1]))
|
||||
xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available()
|
||||
except:
|
||||
pass
|
||||
|
||||
@@ -102,10 +101,14 @@ if args.directml is not None:
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # noqa: F401
|
||||
_ = torch.xpu.device_count()
|
||||
xpu_available = xpu_available or torch.xpu.is_available()
|
||||
except:
|
||||
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
|
||||
pass
|
||||
|
||||
try:
|
||||
_ = torch.xpu.device_count()
|
||||
xpu_available = torch.xpu.is_available()
|
||||
except:
|
||||
xpu_available = False
|
||||
|
||||
try:
|
||||
if torch.backends.mps.is_available():
|
||||
@@ -946,10 +949,12 @@ def pick_weight_dtype(dtype, fallback_dtype, device=None):
|
||||
return dtype
|
||||
|
||||
def device_supports_non_blocking(device):
|
||||
if args.force_non_blocking:
|
||||
return True
|
||||
if is_device_mps(device):
|
||||
return False #pytorch bug? mps doesn't support non blocking
|
||||
if is_intel_xpu():
|
||||
return True
|
||||
if is_intel_xpu(): #xpu does support non blocking but it is slower on iGPUs for some reason so disable by default until situation changes
|
||||
return False
|
||||
if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
|
||||
return False
|
||||
if directml_enabled:
|
||||
@@ -1282,10 +1287,10 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
if is_intel_xpu():
|
||||
if torch_version_numeric < (2, 6):
|
||||
if torch_version_numeric < (2, 3):
|
||||
return True
|
||||
else:
|
||||
return torch.xpu.get_device_capability(device)['has_bfloat16_conversions']
|
||||
return torch.xpu.is_bf16_supported()
|
||||
|
||||
if is_ascend_npu():
|
||||
return True
|
||||
|
26
comfy/ops.py
26
comfy/ops.py
@@ -24,6 +24,32 @@ import comfy.float
|
||||
import comfy.rmsnorm
|
||||
import contextlib
|
||||
|
||||
|
||||
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
|
||||
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
import inspect
|
||||
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
|
||||
SDPA_BACKEND_PRIORITY = [
|
||||
SDPBackend.FLASH_ATTENTION,
|
||||
SDPBackend.EFFICIENT_ATTENTION,
|
||||
SDPBackend.MATH,
|
||||
]
|
||||
|
||||
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
|
||||
|
||||
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||
with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
else:
|
||||
logging.warning("Torch version too old to set sdpa backend priority.")
|
||||
except (ModuleNotFoundError, TypeError):
|
||||
logging.warning("Could not set sdpa backend priority.")
|
||||
|
||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import numbers
|
||||
import logging
|
||||
|
||||
RMSNorm = None
|
||||
|
||||
@@ -9,6 +10,7 @@ try:
|
||||
RMSNorm = torch.nn.RMSNorm
|
||||
except:
|
||||
rms_norm_torch = None
|
||||
logging.warning("Please update pytorch to use native RMSNorm")
|
||||
|
||||
|
||||
def rms_norm(x, weight=None, eps=1e-6):
|
||||
|
@@ -149,7 +149,7 @@ def cleanup_models(conds, models):
|
||||
|
||||
cleanup_additional_models(set(control_cleanup))
|
||||
|
||||
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
||||
def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
|
||||
'''
|
||||
Registers hooks from conds.
|
||||
'''
|
||||
@@ -158,8 +158,8 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
||||
for k in conds:
|
||||
get_hooks_from_cond(conds[k], hooks)
|
||||
# add wrappers and callbacks from ModelPatcher to transformer_options
|
||||
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
|
||||
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
|
||||
comfy.patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("wrappers", {}), model.wrappers, copy_dict1=False)
|
||||
comfy.patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("callbacks", {}), model.callbacks, copy_dict1=False)
|
||||
# begin registering hooks
|
||||
registered = comfy.hooks.HookGroup()
|
||||
target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model)
|
||||
|
@@ -16,6 +16,7 @@ import comfy.sampler_helpers
|
||||
import comfy.model_patcher
|
||||
import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import comfy.context_windows
|
||||
import scipy.stats
|
||||
import numpy
|
||||
|
||||
@@ -198,14 +199,20 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H
|
||||
hooked_to_run.setdefault(p.hooks, list())
|
||||
hooked_to_run[p.hooks] += [(p, i)]
|
||||
|
||||
def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
def calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options: dict[str]):
|
||||
handler: comfy.context_windows.ContextHandlerABC = model_options.get("context_handler", None)
|
||||
if handler is None or not handler.should_use_context(model, conds, x_in, timestep, model_options):
|
||||
return _calc_cond_batch_outer(model, conds, x_in, timestep, model_options)
|
||||
return handler.execute(_calc_cond_batch_outer, model, conds, x_in, timestep, model_options)
|
||||
|
||||
def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||
_calc_cond_batch,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
|
||||
)
|
||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||
|
||||
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
|
@@ -1046,6 +1046,18 @@ class WAN21_Camera(WAN21_T2V):
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class WAN22_Camera(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "camera_2.2",
|
||||
"in_dim": 36,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class WAN21_Vace(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@@ -1260,6 +1272,6 @@ class QwenImage(supported_models_base.BASE):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
@@ -12,7 +12,7 @@ import torch
|
||||
try:
|
||||
import torchaudio
|
||||
TORCH_AUDIO_AVAILABLE = True
|
||||
except ImportError:
|
||||
except:
|
||||
TORCH_AUDIO_AVAILABLE = False
|
||||
from PIL import Image as PILImage
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
@@ -1690,7 +1690,11 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
|
||||
):
|
||||
self.validate_prompt(prompt, negative_prompt)
|
||||
|
||||
if image is not None:
|
||||
if image is None:
|
||||
image_type = None
|
||||
elif model_name == KlingImageGenModelName.kling_v1:
|
||||
raise ValueError(f"The model {KlingImageGenModelName.kling_v1.value} does not support reference images.")
|
||||
else:
|
||||
image = tensor_to_base64_string(image)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
|
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
import random
|
||||
import torch
|
||||
from comfy_api_nodes.util.validation_utils import (
|
||||
get_image_dimensions,
|
||||
@@ -208,20 +207,29 @@ def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
|
||||
def _validate_video_dimensions(width: int, height: int) -> None:
|
||||
"""Validates video dimensions meet Moonvalley V2V requirements."""
|
||||
supported_resolutions = {
|
||||
(1920, 1080), (1080, 1920), (1152, 1152),
|
||||
(1536, 1152), (1152, 1536)
|
||||
(1920, 1080),
|
||||
(1080, 1920),
|
||||
(1152, 1152),
|
||||
(1536, 1152),
|
||||
(1152, 1536),
|
||||
}
|
||||
|
||||
if (width, height) not in supported_resolutions:
|
||||
supported_list = ', '.join([f'{w}x{h}' for w, h in sorted(supported_resolutions)])
|
||||
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
|
||||
supported_list = ", ".join(
|
||||
[f"{w}x{h}" for w, h in sorted(supported_resolutions)]
|
||||
)
|
||||
raise ValueError(
|
||||
f"Resolution {width}x{height} not supported. Supported: {supported_list}"
|
||||
)
|
||||
|
||||
|
||||
def _validate_container_format(video: VideoInput) -> None:
|
||||
"""Validates video container format is MP4."""
|
||||
container_format = video.get_container_format()
|
||||
if container_format not in ['mp4', 'mov,mp4,m4a,3gp,3g2,mj2']:
|
||||
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
|
||||
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
|
||||
raise ValueError(
|
||||
f"Only MP4 container format supported. Got: {container_format}"
|
||||
)
|
||||
|
||||
|
||||
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
||||
@@ -244,7 +252,6 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
|
||||
return video
|
||||
|
||||
|
||||
|
||||
def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
||||
"""
|
||||
Returns a new VideoInput object trimmed from the beginning to the specified duration,
|
||||
@@ -302,7 +309,9 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
||||
# Calculate target frame count that's divisible by 16
|
||||
fps = input_container.streams.video[0].average_rate
|
||||
estimated_frames = int(duration_sec * fps)
|
||||
target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16
|
||||
target_frames = (
|
||||
estimated_frames // 16
|
||||
) * 16 # Round down to nearest multiple of 16
|
||||
|
||||
if target_frames == 0:
|
||||
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
|
||||
@@ -424,7 +433,7 @@ class BaseMoonvalleyVideoNode:
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts",
|
||||
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||
),
|
||||
"resolution": (
|
||||
IO.COMBO,
|
||||
@@ -441,12 +450,11 @@ class BaseMoonvalleyVideoNode:
|
||||
"tooltip": "Resolution of the output video",
|
||||
},
|
||||
),
|
||||
# "length": (IO.COMBO,{"options":['5s','10s'], "default": '5s'}),
|
||||
"prompt_adherence": model_field_to_node_input(
|
||||
IO.FLOAT,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
"guidance_scale",
|
||||
default=7.0,
|
||||
default=10.0,
|
||||
step=1,
|
||||
min=1,
|
||||
max=20,
|
||||
@@ -455,13 +463,12 @@ class BaseMoonvalleyVideoNode:
|
||||
IO.INT,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
"seed",
|
||||
default=random.randint(0, 2**32 - 1),
|
||||
default=9,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
display="number",
|
||||
tooltip="Random seed value",
|
||||
control_after_generate=True,
|
||||
),
|
||||
"steps": model_field_to_node_input(
|
||||
IO.INT,
|
||||
@@ -532,9 +539,11 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
|
||||
# Get MIME type from tensor - assuming PNG format for image tensors
|
||||
mime_type = "image/png"
|
||||
|
||||
image_url = (await upload_images_to_comfyapi(
|
||||
image_url = (
|
||||
await upload_images_to_comfyapi(
|
||||
image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
|
||||
))[0]
|
||||
)
|
||||
)[0]
|
||||
|
||||
request = MoonvalleyTextToVideoRequest(
|
||||
image_url=image_url, prompt_text=prompt, inference_params=inference_params
|
||||
@@ -570,17 +579,39 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": model_field_to_node_input(
|
||||
IO.STRING, MoonvalleyVideoToVideoRequest, "prompt_text",
|
||||
multiline=True
|
||||
IO.STRING,
|
||||
MoonvalleyVideoToVideoRequest,
|
||||
"prompt_text",
|
||||
multiline=True,
|
||||
),
|
||||
"negative_prompt": model_field_to_node_input(
|
||||
IO.STRING,
|
||||
MoonvalleyVideoToVideoInferenceParams,
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts"
|
||||
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||
),
|
||||
"seed": model_field_to_node_input(
|
||||
IO.INT,
|
||||
MoonvalleyVideoToVideoInferenceParams,
|
||||
"seed",
|
||||
default=9,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
display="number",
|
||||
tooltip="Random seed value",
|
||||
control_after_generate=False,
|
||||
),
|
||||
"prompt_adherence": model_field_to_node_input(
|
||||
IO.FLOAT,
|
||||
MoonvalleyVideoToVideoInferenceParams,
|
||||
"guidance_scale",
|
||||
default=10.0,
|
||||
step=1,
|
||||
min=1,
|
||||
max=20,
|
||||
),
|
||||
"seed": model_field_to_node_input(IO.INT,MoonvalleyVideoToVideoInferenceParams, "seed", default=random.randint(0, 2**32 - 1), min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", control_after_generate=True),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
@@ -588,7 +619,14 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
"optional": {
|
||||
"video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported."}),
|
||||
"video": (
|
||||
IO.VIDEO,
|
||||
{
|
||||
"default": "",
|
||||
"multiline": False,
|
||||
"tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
|
||||
},
|
||||
),
|
||||
"control_type": (
|
||||
["Motion Transfer", "Pose Transfer"],
|
||||
{"default": "Motion Transfer"},
|
||||
@@ -602,8 +640,14 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
||||
"max": 100,
|
||||
"tooltip": "Only used if control_type is 'Motion Transfer'",
|
||||
},
|
||||
)
|
||||
}
|
||||
),
|
||||
"image": model_field_to_node_input(
|
||||
IO.IMAGE,
|
||||
MoonvalleyTextToVideoRequest,
|
||||
"image_url",
|
||||
tooltip="The reference image used to generate the video",
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("VIDEO",)
|
||||
@@ -613,6 +657,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
||||
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
|
||||
):
|
||||
video = kwargs.get("video")
|
||||
image = kwargs.get("image", None)
|
||||
|
||||
if not video:
|
||||
raise MoonvalleyApiError("video is required")
|
||||
@@ -620,8 +665,16 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
||||
video_url = ""
|
||||
if video:
|
||||
validated_video = validate_video_to_video_input(video)
|
||||
video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs)
|
||||
video_url = await upload_video_to_comfyapi(
|
||||
validated_video, auth_kwargs=kwargs
|
||||
)
|
||||
mime_type = "image/png"
|
||||
|
||||
if not image is None:
|
||||
validate_input_image(image, with_frame_conditioning=True)
|
||||
image_url = await upload_images_to_comfyapi(
|
||||
image=image, auth_kwargs=kwargs, max_images=1, mime_type=mime_type
|
||||
)
|
||||
control_type = kwargs.get("control_type")
|
||||
motion_intensity = kwargs.get("motion_intensity")
|
||||
|
||||
@@ -631,12 +684,12 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
||||
# Only include motion_intensity for Motion Transfer
|
||||
control_params = {}
|
||||
if control_type == "Motion Transfer" and motion_intensity is not None:
|
||||
control_params['motion_intensity'] = motion_intensity
|
||||
control_params["motion_intensity"] = motion_intensity
|
||||
|
||||
inference_params=MoonvalleyVideoToVideoInferenceParams(
|
||||
inference_params = MoonvalleyVideoToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
seed=kwargs.get("seed"),
|
||||
control_params=control_params
|
||||
control_params=control_params,
|
||||
)
|
||||
|
||||
control = self.parseControlParameter(control_type)
|
||||
@@ -647,6 +700,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
||||
prompt_text=prompt,
|
||||
inference_params=inference_params,
|
||||
)
|
||||
request.image_url = image_url if not image is None else None
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
@@ -694,7 +748,7 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
|
||||
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
|
||||
|
||||
inference_params=MoonvalleyTextToVideoInferenceParams(
|
||||
inference_params = MoonvalleyTextToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
steps=kwargs.get("steps"),
|
||||
seed=kwargs.get("seed"),
|
||||
|
@@ -464,8 +464,6 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
||||
path = "/proxy/openai/images/generations"
|
||||
content_type = "application/json"
|
||||
request_class = OpenAIImageGenerationRequest
|
||||
img_binaries = []
|
||||
mask_binary = None
|
||||
files = []
|
||||
|
||||
if image is not None:
|
||||
@@ -484,14 +482,11 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
||||
img_byte_arr = io.BytesIO()
|
||||
img.save(img_byte_arr, format="PNG")
|
||||
img_byte_arr.seek(0)
|
||||
img_binary = img_byte_arr
|
||||
img_binary.name = f"image_{i}.png"
|
||||
|
||||
img_binaries.append(img_binary)
|
||||
if batch_size == 1:
|
||||
files.append(("image", img_binary))
|
||||
files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png")))
|
||||
else:
|
||||
files.append(("image[]", img_binary))
|
||||
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
|
||||
|
||||
if mask is not None:
|
||||
if image is None:
|
||||
@@ -511,9 +506,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
||||
mask_img_byte_arr = io.BytesIO()
|
||||
mask_img.save(mask_img_byte_arr, format="PNG")
|
||||
mask_img_byte_arr.seek(0)
|
||||
mask_binary = mask_img_byte_arr
|
||||
mask_binary.name = "mask.png"
|
||||
files.append(("mask", mask_binary))
|
||||
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
|
||||
|
||||
# Build the operation
|
||||
operation = SynchronousOperation(
|
||||
|
@@ -346,6 +346,24 @@ class LoadAudio:
|
||||
return "Invalid audio file: {}".format(audio)
|
||||
return True
|
||||
|
||||
class RecordAudio:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"audio": ("AUDIO_RECORD", {})}}
|
||||
|
||||
CATEGORY = "audio"
|
||||
|
||||
RETURN_TYPES = ("AUDIO", )
|
||||
FUNCTION = "load"
|
||||
|
||||
def load(self, audio):
|
||||
audio_path = folder_paths.get_annotated_filepath(audio)
|
||||
|
||||
waveform, sample_rate = torchaudio.load(audio_path)
|
||||
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
|
||||
return (audio, )
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"EmptyLatentAudio": EmptyLatentAudio,
|
||||
"VAEEncodeAudio": VAEEncodeAudio,
|
||||
@@ -356,6 +374,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"LoadAudio": LoadAudio,
|
||||
"PreviewAudio": PreviewAudio,
|
||||
"ConditioningStableAudio": ConditioningStableAudio,
|
||||
"RecordAudio": RecordAudio,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@@ -367,4 +386,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"SaveAudio": "Save Audio (FLAC)",
|
||||
"SaveAudioMP3": "Save Audio (MP3)",
|
||||
"SaveAudioOpus": "Save Audio (Opus)",
|
||||
"RecordAudio": "Record Audio",
|
||||
}
|
||||
|
89
comfy_extras/nodes_context_windows.py
Normal file
89
comfy_extras/nodes_context_windows.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import comfy.context_windows
|
||||
import nodes
|
||||
|
||||
|
||||
class ContextWindowsManualNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="ContextWindowsManual",
|
||||
display_name="Context Windows (Manual)",
|
||||
category="context",
|
||||
description="Manually set context windows.",
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
||||
io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."),
|
||||
io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."),
|
||||
io.Combo.Input("context_schedule", options=[
|
||||
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
|
||||
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
|
||||
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
|
||||
comfy.context_windows.ContextSchedules.BATCHED,
|
||||
], tooltip="The stride of the context window."),
|
||||
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
|
||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model:
|
||||
model = model.clone()
|
||||
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
||||
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
||||
fuse_method=comfy.context_windows.get_matching_fuse_method(fuse_method),
|
||||
context_length=context_length,
|
||||
context_overlap=context_overlap,
|
||||
context_stride=context_stride,
|
||||
closed_loop=closed_loop,
|
||||
dim=dim)
|
||||
# make memory usage calculation only take into account the context window latents
|
||||
comfy.context_windows.create_prepare_sampling_wrapper(model)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
class WanContextWindowsManualNode(ContextWindowsManualNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
schema = super().define_schema()
|
||||
schema.node_id = "WanContextWindowsManual"
|
||||
schema.display_name = "WAN Context Windows (Manual)"
|
||||
schema.description = "Manually set context windows for WAN-like models (dim=2)."
|
||||
schema.inputs = [
|
||||
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
||||
io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window."),
|
||||
io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window."),
|
||||
io.Combo.Input("context_schedule", options=[
|
||||
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
|
||||
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
|
||||
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
|
||||
comfy.context_windows.ContextSchedules.BATCHED,
|
||||
], tooltip="The stride of the context window."),
|
||||
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
|
||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||
]
|
||||
return schema
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model:
|
||||
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
|
||||
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
|
||||
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2)
|
||||
|
||||
|
||||
class ContextWindowsExtension(ComfyExtension):
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
ContextWindowsManualNode,
|
||||
WanContextWindowsManualNode,
|
||||
]
|
||||
|
||||
def comfy_entrypoint():
|
||||
return ContextWindowsExtension()
|
@@ -100,9 +100,28 @@ class FluxKontextImageScale:
|
||||
return (image, )
|
||||
|
||||
|
||||
class FluxKontextMultiReferenceLatentMethod:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"conditioning": ("CONDITIONING", ),
|
||||
"reference_latents_method": (("offset", "index"), ),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "append"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
CATEGORY = "advanced/conditioning/flux"
|
||||
|
||||
def append(self, conditioning, reference_latents_method):
|
||||
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
|
||||
return (c, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
|
||||
"FluxGuidance": FluxGuidance,
|
||||
"FluxDisableGuidance": FluxDisableGuidance,
|
||||
"FluxKontextImageScale": FluxKontextImageScale,
|
||||
"FluxKontextMultiReferenceLatentMethod": FluxKontextMultiReferenceLatentMethod,
|
||||
}
|
||||
|
@@ -9,29 +9,35 @@ import comfy.clip_vision
|
||||
import json
|
||||
import numpy as np
|
||||
from typing import Tuple
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
class WanImageToVideo:
|
||||
class WanImageToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
||||
"start_image": ("IMAGE", ),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanImageToVideo",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||
io.Image.Input("start_image", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None):
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
@@ -51,32 +57,36 @@ class WanImageToVideo:
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (positive, negative, out_latent)
|
||||
return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
|
||||
class WanFunControlToVideo:
|
||||
class WanFunControlToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
||||
"start_image": ("IMAGE", ),
|
||||
"control_video": ("IMAGE", ),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanFunControlToVideo",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||
io.Image.Input("start_image", optional=True),
|
||||
io.Image.Input("control_video", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None):
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
|
||||
@@ -101,32 +111,96 @@ class WanFunControlToVideo:
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (positive, negative, out_latent)
|
||||
return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
class WanFirstLastFrameToVideo:
|
||||
class Wan22FunControlToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ),
|
||||
"clip_vision_end_image": ("CLIP_VISION_OUTPUT", ),
|
||||
"start_image": ("IMAGE", ),
|
||||
"end_image": ("IMAGE", ),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Wan22FunControlToVideo",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.Image.Input("ref_image", optional=True),
|
||||
io.Image.Input("control_video", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
FUNCTION = "encode"
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
|
||||
concat_latent = concat_latent.repeat(1, 2, 1, 1, 1)
|
||||
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
concat_latent_image = vae.encode(start_image[:, :, :, :3])
|
||||
concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
|
||||
mask[:, :, :start_image.shape[0] + 3] = 0.0
|
||||
|
||||
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None):
|
||||
ref_latent = None
|
||||
if ref_image is not None:
|
||||
ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
ref_latent = vae.encode(ref_image[:, :, :, :3])
|
||||
|
||||
if control_video is not None:
|
||||
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
concat_latent_image = vae.encode(control_video[:, :, :, :3])
|
||||
concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
|
||||
|
||||
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16})
|
||||
|
||||
if ref_latent is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
|
||||
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
class WanFirstLastFrameToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanFirstLastFrameToVideo",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.ClipVisionOutput.Input("clip_vision_start_image", optional=True),
|
||||
io.ClipVisionOutput.Input("clip_vision_end_image", optional=True),
|
||||
io.Image.Input("start_image", optional=True),
|
||||
io.Image.Input("end_image", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
@@ -167,62 +241,70 @@ class WanFirstLastFrameToVideo:
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (positive, negative, out_latent)
|
||||
return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
|
||||
class WanFunInpaintToVideo:
|
||||
class WanFunInpaintToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
||||
"start_image": ("IMAGE", ),
|
||||
"end_image": ("IMAGE", ),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanFunInpaintToVideo",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||
io.Image.Input("start_image", optional=True),
|
||||
io.Image.Input("end_image", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None) -> io.NodeOutput:
|
||||
flfv = WanFirstLastFrameToVideo()
|
||||
return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
|
||||
return flfv.execute(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
|
||||
|
||||
|
||||
class WanVaceToVideo:
|
||||
class WanVaceToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
|
||||
},
|
||||
"optional": {"control_video": ("IMAGE", ),
|
||||
"control_masks": ("MASK", ),
|
||||
"reference_image": ("IMAGE", ),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanVaceToVideo",
|
||||
category="conditioning/video_models",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=1000.0, step=0.01),
|
||||
io.Image.Input("control_video", optional=True),
|
||||
io.Mask.Input("control_masks", optional=True),
|
||||
io.Image.Input("reference_image", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
io.Int.Output(display_name="trim_latent"),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent", "trim_latent")
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def encode(self, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None):
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None) -> io.NodeOutput:
|
||||
latent_length = ((length - 1) // 4) + 1
|
||||
if control_video is not None:
|
||||
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
@@ -279,52 +361,59 @@ class WanVaceToVideo:
|
||||
latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (positive, negative, out_latent, trim_latent)
|
||||
return io.NodeOutput(positive, negative, out_latent, trim_latent)
|
||||
|
||||
class TrimVideoLatent:
|
||||
class TrimVideoLatent(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
"trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="TrimVideoLatent",
|
||||
category="latent/video",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Latent.Input("samples"),
|
||||
io.Int.Input("trim_amount", default=0, min=0, max=99999),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "op"
|
||||
|
||||
CATEGORY = "latent/video"
|
||||
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def op(self, samples, trim_amount):
|
||||
@classmethod
|
||||
def execute(cls, samples, trim_amount) -> io.NodeOutput:
|
||||
samples_out = samples.copy()
|
||||
|
||||
s1 = samples["samples"]
|
||||
samples_out["samples"] = s1[:, :, trim_amount:]
|
||||
return (samples_out,)
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
class WanCameraImageToVideo:
|
||||
class WanCameraImageToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
||||
"start_image": ("IMAGE", ),
|
||||
"camera_conditions": ("WAN_CAMERA_EMBEDDING", ),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanCameraImageToVideo",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||
io.Image.Input("start_image", optional=True),
|
||||
io.WanCameraEmbedding.Input("camera_conditions", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None):
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
|
||||
@@ -333,9 +422,12 @@ class WanCameraImageToVideo:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
concat_latent_image = vae.encode(start_image[:, :, :, :3])
|
||||
concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
|
||||
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
|
||||
mask[:, :, :start_image.shape[0] + 3] = 0.0
|
||||
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask})
|
||||
|
||||
if camera_conditions is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions})
|
||||
@@ -347,29 +439,34 @@ class WanCameraImageToVideo:
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (positive, negative, out_latent)
|
||||
return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
class WanPhantomSubjectToVideo:
|
||||
class WanPhantomSubjectToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"images": ("IMAGE", ),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanPhantomSubjectToVideo",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.Image.Input("images", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative_text"),
|
||||
io.Conditioning.Output(display_name="negative_img_text"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent")
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, positive, negative, vae, width, height, length, batch_size, images):
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, images) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
cond2 = negative
|
||||
if images is not None:
|
||||
@@ -385,7 +482,7 @@ class WanPhantomSubjectToVideo:
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (positive, cond2, negative, out_latent)
|
||||
return io.NodeOutput(positive, cond2, negative, out_latent)
|
||||
|
||||
def parse_json_tracks(tracks):
|
||||
"""Parse JSON track data into a standardized format"""
|
||||
@@ -598,39 +695,41 @@ def patch_motion(
|
||||
|
||||
return out_mask_full, out_feature_full
|
||||
|
||||
class WanTrackToVideo:
|
||||
class WanTrackToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
"tracks": ("STRING", {"multiline": True, "default": "[]"}),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
"temperature": ("FLOAT", {"default": 220.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
|
||||
"topk": ("INT", {"default": 2, "min": 1, "max": 10}),
|
||||
"start_image": ("IMAGE", ),
|
||||
},
|
||||
"optional": {
|
||||
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanTrackToVideo",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.String.Input("tracks", multiline=True, default="[]"),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.Float.Input("temperature", default=220.0, min=1.0, max=1000.0, step=0.1),
|
||||
io.Int.Input("topk", default=2, min=1, max=10),
|
||||
io.Image.Input("start_image"),
|
||||
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, positive, negative, vae, tracks, width, height, length, batch_size,
|
||||
temperature, topk, start_image=None, clip_vision_output=None):
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, tracks, width, height, length, batch_size,
|
||||
temperature, topk, start_image=None, clip_vision_output=None) -> io.NodeOutput:
|
||||
|
||||
tracks_data = parse_json_tracks(tracks)
|
||||
|
||||
if not tracks_data:
|
||||
return WanImageToVideo().encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output)
|
||||
return WanImageToVideo().execute(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output)
|
||||
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
|
||||
device=comfy.model_management.intermediate_device())
|
||||
@@ -684,34 +783,36 @@ class WanTrackToVideo:
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (positive, negative, out_latent)
|
||||
return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
|
||||
class Wan22ImageToVideoLatent:
|
||||
class Wan22ImageToVideoLatent(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 1280, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
||||
"height": ("INT", {"default": 704, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
||||
"length": ("INT", {"default": 49, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"start_image": ("IMAGE", ),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Wan22ImageToVideoLatent",
|
||||
category="conditioning/inpaint",
|
||||
inputs=[
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=1280, min=32, max=nodes.MAX_RESOLUTION, step=32),
|
||||
io.Int.Input("height", default=704, min=32, max=nodes.MAX_RESOLUTION, step=32),
|
||||
io.Int.Input("length", default=49, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.Image.Input("start_image", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/inpaint"
|
||||
|
||||
def encode(self, vae, width, height, length, batch_size, start_image=None):
|
||||
@classmethod
|
||||
def execute(cls, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([1, 48, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
|
||||
|
||||
if start_image is None:
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (out_latent,)
|
||||
return io.NodeOutput(out_latent)
|
||||
|
||||
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
|
||||
|
||||
@@ -726,18 +827,25 @@ class Wan22ImageToVideoLatent:
|
||||
latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
|
||||
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
|
||||
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
|
||||
return (out_latent,)
|
||||
return io.NodeOutput(out_latent)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"WanTrackToVideo": WanTrackToVideo,
|
||||
"WanImageToVideo": WanImageToVideo,
|
||||
"WanFunControlToVideo": WanFunControlToVideo,
|
||||
"WanFunInpaintToVideo": WanFunInpaintToVideo,
|
||||
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
|
||||
"WanVaceToVideo": WanVaceToVideo,
|
||||
"TrimVideoLatent": TrimVideoLatent,
|
||||
"WanCameraImageToVideo": WanCameraImageToVideo,
|
||||
"WanPhantomSubjectToVideo": WanPhantomSubjectToVideo,
|
||||
"Wan22ImageToVideoLatent": Wan22ImageToVideoLatent,
|
||||
}
|
||||
class WanExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
WanTrackToVideo,
|
||||
WanImageToVideo,
|
||||
WanFunControlToVideo,
|
||||
Wan22FunControlToVideo,
|
||||
WanFunInpaintToVideo,
|
||||
WanFirstLastFrameToVideo,
|
||||
WanVaceToVideo,
|
||||
TrimVideoLatent,
|
||||
WanCameraImageToVideo,
|
||||
WanPhantomSubjectToVideo,
|
||||
Wan22ImageToVideoLatent,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> WanExtension:
|
||||
return WanExtension()
|
||||
|
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.3.49"
|
||||
__version__ = "0.3.50"
|
||||
|
1
nodes.py
1
nodes.py
@@ -2320,6 +2320,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_camera_trajectory.py",
|
||||
"nodes_edit_model.py",
|
||||
"nodes_tcfg.py",
|
||||
"nodes_context_windows.py",
|
||||
"nodes_assets_test.py",
|
||||
]
|
||||
|
||||
|
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.3.49"
|
||||
version = "0.3.50"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
|
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.24.4
|
||||
comfyui-workflow-templates==0.1.53
|
||||
comfyui-frontend-package==1.25.9
|
||||
comfyui-workflow-templates==0.1.60
|
||||
comfyui-embedded-docs==0.2.6
|
||||
torch
|
||||
torchsde
|
||||
@@ -20,12 +20,12 @@ tqdm
|
||||
psutil
|
||||
alembic
|
||||
SQLAlchemy
|
||||
av>=14.2.0
|
||||
blake3
|
||||
|
||||
#non essential dependencies:
|
||||
kornia>=0.7.1
|
||||
spandrel
|
||||
soundfile
|
||||
av>=14.2.0
|
||||
pydantic~=2.0
|
||||
pydantic-settings~=2.0
|
||||
|
Reference in New Issue
Block a user