Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-07-27 16:26:39 +00:00)

Commit 9bd3faaf1f: Merge branch 'v3-definition' into v3-definition-wip
@@ -294,6 +294,13 @@ For models compatible with Cambricon Extension for PyTorch (torch_mlu). Here's a
 2. Next, install the PyTorch(torch_mlu) following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html)
 3. Launch ComfyUI by running `python main.py`
 
+#### Iluvatar Corex
+
+For models compatible with the Iluvatar Extension for PyTorch, here's a step-by-step guide tailored to your platform and installation method:
+
+1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536) page
+
+2. Launch ComfyUI by running `python main.py`
 
 # Running
 
 ```python main.py```
@@ -29,18 +29,48 @@ def frontend_install_warning_message():
 This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
 """.strip()
 
+
+def parse_version(version: str) -> tuple[int, int, int]:
+    return tuple(map(int, version.split(".")))
+
+
+def is_valid_version(version: str) -> bool:
+    """Validate if a string is a valid semantic version (X.Y.Z format)."""
+    pattern = r"^(\d+)\.(\d+)\.(\d+)$"
+    return bool(re.match(pattern, version))
+
+
+def get_installed_frontend_version():
+    """Get the currently installed frontend package version."""
+    frontend_version_str = version("comfyui-frontend-package")
+    return frontend_version_str
+
+
+def get_required_frontend_version():
+    """Get the required frontend version from requirements.txt."""
+    try:
+        with open(requirements_path, "r", encoding="utf-8") as f:
+            for line in f:
+                line = line.strip()
+                if line.startswith("comfyui-frontend-package=="):
+                    version_str = line.split("==")[-1]
+                    if not is_valid_version(version_str):
+                        logging.error(f"Invalid version format in requirements.txt: {version_str}")
+                        return None
+                    return version_str
+            logging.error("comfyui-frontend-package not found in requirements.txt")
+            return None
+    except FileNotFoundError:
+        logging.error("requirements.txt not found. Cannot determine required frontend version.")
+        return None
+    except Exception as e:
+        logging.error(f"Error reading requirements.txt: {e}")
+        return None
+
+
 def check_frontend_version():
     """Check if the frontend version is up to date."""
 
-    def parse_version(version: str) -> tuple[int, int, int]:
-        return tuple(map(int, version.split(".")))
-
     try:
-        frontend_version_str = version("comfyui-frontend-package")
+        frontend_version_str = get_installed_frontend_version()
         frontend_version = parse_version(frontend_version_str)
-        with open(requirements_path, "r", encoding="utf-8") as f:
-            required_frontend = parse_version(f.readline().split("=")[-1])
+        required_frontend_str = get_required_frontend_version()
+        required_frontend = parse_version(required_frontend_str)
         if frontend_version < required_frontend:
             app.logger.log_startup_warning(
                 f"""
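Taken together, the new helpers split what used to be one monolithic check into small, testable pieces. A minimal standalone sketch of how they compose (the version strings below are illustrative stand-ins for the requirements.txt pin and the installed package):

```python
import re

def parse_version(version: str) -> tuple[int, int, int]:
    return tuple(map(int, version.split(".")))

def is_valid_version(version: str) -> bool:
    return bool(re.match(r"^(\d+)\.(\d+)\.(\d+)$", version))

required = "1.23.4"    # e.g. parsed from "comfyui-frontend-package==1.23.4"
installed = "1.22.0"   # e.g. from importlib.metadata.version(...)

assert is_valid_version(required)
if parse_version(installed) < parse_version(required):
    print("frontend out of date")  # tuple comparison is elementwise, left to right
```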
@@ -168,6 +198,11 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
 class FrontendManager:
     CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
 
+    @classmethod
+    def get_required_frontend_version(cls) -> str:
+        """Get the required frontend package version."""
+        return get_required_frontend_version()
+
     @classmethod
     def default_frontend_path(cls) -> str:
         try:
@@ -1210,39 +1210,21 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
     return x_next
 
 
-@torch.no_grad()
-def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
-    extra_args = {} if extra_args is None else extra_args
-
-    temp = [0]
-    def post_cfg_function(args):
-        temp[0] = args["uncond_denoised"]
-        return args["denoised"]
-
-    model_options = extra_args.get("model_options", {}).copy()
-    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
-
-    s_in = x.new_ones([x.shape[0]])
-    for i in trange(len(sigmas) - 1, disable=disable):
-        sigma_hat = sigmas[i]
-        denoised = model(x, sigma_hat * s_in, **extra_args)
-        d = to_d(x, sigma_hat, temp[0])
-        if callback is not None:
-            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
-        # Euler method
-        x = denoised + d * sigmas[i + 1]
-    return x
-
 @torch.no_grad()
 def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
-    """Ancestral sampling with Euler method steps."""
+    """Ancestral sampling with Euler method steps (CFG++)."""
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
 
-    temp = [0]
+    model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
+    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
+
+    uncond_denoised = None
+
     def post_cfg_function(args):
-        temp[0] = args["uncond_denoised"]
+        nonlocal uncond_denoised
+        uncond_denoised = args["uncond_denoised"]
         return args["denoised"]
 
     model_options = extra_args.get("model_options", {}).copy()
@@ -1251,15 +1233,33 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, eta=1., s_noise=1., noise_sampler=None):
     s_in = x.new_ones([x.shape[0]])
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
-        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
         if callback is not None:
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
-        d = to_d(x, sigmas[i], temp[0])
-        # Euler method
-        x = denoised + d * sigma_down
-        if sigmas[i + 1] > 0:
-            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
+        if sigmas[i + 1] == 0:
+            # Denoising step
+            x = denoised
+        else:
+            alpha_s = sigmas[i] * lambda_fn(sigmas[i]).exp()
+            alpha_t = sigmas[i + 1] * lambda_fn(sigmas[i + 1]).exp()
+            d = to_d(x, sigmas[i], alpha_s * uncond_denoised)  # to noise
+
+            # DDIM stochastic sampling
+            sigma_down, sigma_up = get_ancestral_step(sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta)
+            sigma_down = alpha_t * sigma_down
+
+            # Euler method
+            x = alpha_t * denoised + sigma_down * d
+            if eta > 0 and s_noise > 0:
+                x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
     return x
 
 
+@torch.no_grad()
+def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
+    """Euler method steps (CFG++)."""
+    return sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None)
+
+
 @torch.no_grad()
 def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
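After this refactor, sample_euler_cfg_pp is just the deterministic limit of the ancestral sampler. A quick sanity sketch of why eta=0 collapses the DDIM step back to a plain Euler update, reproducing the get_ancestral_step formula used in this file, with plain floats standing in for tensors and illustrative alphas/sigmas:

```python
def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
    # Same formula as the k-diffusion helper called in the hunk above.
    sigma_up = min(sigma_to, eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up

alpha_s, alpha_t = 0.9, 0.95   # illustrative variance-preserving alphas
s_from, s_to = 1.0, 0.5        # illustrative sigmas

down, up = get_ancestral_step(s_from / alpha_s, s_to / alpha_t, eta=0.0)
assert up == 0.0                            # no ancestral noise is injected
assert abs(alpha_t * down - s_to) < 1e-12   # alpha_t * sigma_down == sigmas[i + 1]
# So x = alpha_t * denoised + sigmas[i + 1] * d: an ordinary Euler step.
```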
@@ -128,6 +128,11 @@ try:
 except:
     mlu_available = False
 
+try:
+    ixuca_available = hasattr(torch, "corex")
+except:
+    ixuca_available = False
+
 if args.cpu:
     cpu_state = CPUState.CPU
 
@@ -151,6 +156,12 @@ def is_mlu():
         return True
     return False
 
+def is_ixuca():
+    global ixuca_available
+    if ixuca_available:
+        return True
+    return False
+
 def get_torch_device():
     global directml_enabled
     global cpu_state
@@ -289,7 +300,7 @@ try:
     if torch_version_numeric[0] >= 2:
         if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
-    if is_intel_xpu() or is_ascend_npu() or is_mlu():
+    if is_intel_xpu() or is_ascend_npu() or is_mlu() or is_ixuca():
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
 except:
@@ -1045,6 +1056,8 @@ def xformers_enabled():
         return False
     if is_mlu():
         return False
+    if is_ixuca():
+        return False
     if directml_enabled:
         return False
     return XFORMERS_IS_AVAILABLE
@@ -1080,6 +1093,8 @@ def pytorch_attention_flash_attention():
         return True
     if is_amd():
         return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
+    if is_ixuca():
+        return True
     return False
 
 def force_upcast_attention_dtype():
@@ -1205,6 +1220,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if is_mlu():
         return True
 
+    if is_ixuca():
+        return True
+
     if torch.version.hip:
         return True
 
@@ -1268,6 +1286,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if is_ascend_npu():
         return True
 
+    if is_ixuca():
+        return True
+
     if is_amd():
         arch = torch.cuda.get_device_properties(device).gcnArchName
         if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
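The ixuca wiring follows the probe-once pattern this file already uses for mlu and ascend: detect the backend at import time, then gate features through a cheap predicate. A minimal sketch of the same pattern, assuming only that Iluvatar builds of PyTorch expose a torch.corex attribute (as the hunk suggests):

```python
import torch

try:
    # Any failure during detection just means the backend is absent.
    ixuca_available = hasattr(torch, "corex")
except Exception:
    ixuca_available = False

def is_ixuca() -> bool:
    return ixuca_available
```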
@@ -15,9 +15,20 @@ adapters: list[type[WeightAdapterBase]] = [
     OFTAdapter,
     BOFTAdapter,
 ]
+adapter_maps: dict[str, type[WeightAdapterBase]] = {
+    "LoRA": LoRAAdapter,
+    "LoHa": LoHaAdapter,
+    "LoKr": LoKrAdapter,
+    "OFT": OFTAdapter,
+    ## We disable not implemented algo for now
+    # "GLoRA": GLoRAAdapter,
+    # "BOFT": BOFTAdapter,
+}
+
 
 __all__ = [
     "WeightAdapterBase",
     "WeightAdapterTrainBase",
-    "adapters"
+    "adapters",
+    "adapter_maps",
 ] + [a.__name__ for a in adapters]
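adapter_maps gives callers a stable, human-readable key per training algorithm instead of relying on positions in the adapters list. A minimal usage sketch (the weight tensor here is an illustrative stand-in):

```python
import torch
from comfy.weight_adapter import adapter_maps

weight = torch.randn(320, 320)          # hypothetical model weight
adapter_cls = adapter_maps["LoRA"]      # select by UI-facing name
train_adapter = adapter_cls.create_train(weight, rank=8, alpha=8.0)
```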
@@ -133,3 +133,43 @@ def tucker_weight_from_conv(up, down, mid):
 def tucker_weight(wa, wb, t):
     temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
     return torch.einsum("i j ..., i r -> r j ...", temp, wa)
+
+
+def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
+    """
+    Return a tuple of two values whose product is the input dimension, decomposed by the number closest to factor;
+    the second value is higher than or equal to the first value.
+
+    examples)
+    factor
+         -1              2               4               8               16              ...
+    127 -> 1, 127   127 -> 1, 127   127 -> 1, 127   127 -> 1, 127   127 -> 1, 127
+    128 -> 8, 16    128 -> 2, 64    128 -> 4, 32    128 -> 8, 16    128 -> 8, 16
+    250 -> 10, 25   250 -> 2, 125   250 -> 2, 125   250 -> 5, 50    250 -> 10, 25
+    360 -> 8, 45    360 -> 2, 180   360 -> 4, 90    360 -> 8, 45    360 -> 12, 30
+    512 -> 16, 32   512 -> 2, 256   512 -> 4, 128   512 -> 8, 64    512 -> 16, 32
+    1024 -> 32, 32  1024 -> 2, 512  1024 -> 4, 256  1024 -> 8, 128  1024 -> 16, 64
+    """
+    if factor > 0 and (dimension % factor) == 0 and dimension >= factor**2:
+        m = factor
+        n = dimension // factor
+        if m > n:
+            n, m = m, n
+        return m, n
+    if factor < 0:
+        factor = dimension
+    m, n = 1, dimension
+    length = m + n
+    while m < n:
+        new_m = m + 1
+        while dimension % new_m != 0:
+            new_m += 1
+        new_n = dimension // new_m
+        if new_m + new_n > length or new_m > factor:
+            break
+        else:
+            m, n = new_m, new_n
+    if m > n:
+        n, m = m, n
+    return m, n
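The docstring table reads as executable examples; a few spot checks, assuming factorization is importable as defined above:

```python
assert factorization(128, 8) == (8, 16)     # requested factor divides evenly
assert factorization(127, -1) == (1, 127)   # primes can only split as 1 x n
assert factorization(1024, -1) == (32, 32)  # factor=-1 finds the most square split
assert factorization(250, 4) == (2, 125)    # falls back when 4 does not divide 250
```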
@@ -3,7 +3,120 @@ from typing import Optional
 
 import torch
 import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose
+from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose
+
+
+class HadaWeight(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
+        ctx.save_for_backward(w1d, w1u, w2d, w2u, scale)
+        diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale
+        return diff_weight
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        (w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors
+        grad_out = grad_out * scale
+        temp = grad_out * (w2u @ w2d)
+        grad_w1u = temp @ w1d.T
+        grad_w1d = w1u.T @ temp
+
+        temp = grad_out * (w1u @ w1d)
+        grad_w2u = temp @ w2d.T
+        grad_w2d = w2u.T @ temp
+
+        del temp
+        return grad_w1u, grad_w1d, grad_w2u, grad_w2d, None
+
+
+class HadaWeightTucker(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, t1, w1u, w1d, t2, w2u, w2d, scale=torch.tensor(1)):
+        ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale)
+
+        rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u)
+        rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u)
+
+        return rebuild1 * rebuild2 * scale
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        (t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors
+        grad_out = grad_out * scale
+
+        temp = torch.einsum("i j ..., j r -> i r ...", t2, w2d)
+        rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2u)
+
+        grad_w = rebuild * grad_out
+        del rebuild
+
+        grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
+        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T)
+        del grad_w, temp
+
+        grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
+        grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T)
+        del grad_temp
+
+        temp = torch.einsum("i j ..., j r -> i r ...", t1, w1d)
+        rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1u)
+
+        grad_w = rebuild * grad_out
+        del rebuild
+
+        grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
+        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T)
+        del grad_w, temp
+
+        grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
+        grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T)
+        del grad_temp
+        return grad_t1, grad_w1u, grad_w1d, grad_t2, grad_w2u, grad_w2d, None
+
+
+class LohaDiff(WeightAdapterTrainBase):
+    def __init__(self, weights):
+        super().__init__()
+        # Unpack weights tuple from LoHaAdapter
+        w1a, w1b, alpha, w2a, w2b, t1, t2, _ = weights
+
+        # Create trainable parameters
+        self.hada_w1_a = torch.nn.Parameter(w1a)
+        self.hada_w1_b = torch.nn.Parameter(w1b)
+        self.hada_w2_a = torch.nn.Parameter(w2a)
+        self.hada_w2_b = torch.nn.Parameter(w2b)
+
+        self.use_tucker = False
+        if t1 is not None and t2 is not None:
+            self.use_tucker = True
+            self.hada_t1 = torch.nn.Parameter(t1)
+            self.hada_t2 = torch.nn.Parameter(t2)
+        else:
+            # Keep the attributes for consistent access
+            self.hada_t1 = None
+            self.hada_t2 = None
+
+        # Store rank and non-trainable alpha
+        self.rank = w1b.shape[0]
+        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
+
+    def __call__(self, w):
+        org_dtype = w.dtype
+
+        scale = self.alpha / self.rank
+        if self.use_tucker:
+            diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale)
+        else:
+            diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
+
+        # Add the scaled difference to the original weight
+        weight = w.to(diff_weight) + diff_weight.reshape(w.shape)
+
+        return weight.to(org_dtype)
+
+    def passive_memory_usage(self):
+        """Calculates memory usage of the trainable parameters."""
+        return sum(param.numel() * param.element_size() for param in self.parameters())
+
 
 class LoHaAdapter(WeightAdapterBase):
 
@@ -13,6 +126,25 @@ class LoHaAdapter(WeightAdapterBase):
         self.loaded_keys = loaded_keys
         self.weights = weights
 
+    @classmethod
+    def create_train(cls, weight, rank=1, alpha=1.0):
+        out_dim = weight.shape[0]
+        in_dim = weight.shape[1:].numel()
+        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
+        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        torch.nn.init.normal_(mat1, 0.1)
+        torch.nn.init.constant_(mat2, 0.0)
+        mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
+        mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        torch.nn.init.normal_(mat3, 0.1)
+        torch.nn.init.normal_(mat4, 0.01)
+        return LohaDiff(
+            (mat1, mat2, alpha, mat3, mat4, None, None, None)
+        )
+
+    def to_train(self):
+        return LohaDiff(self.weights)
+
     @classmethod
     def load(
         cls,
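HadaWeight exists so the Hadamard (element-wise) product of two low-rank products can be trained without autograd retaining every intermediate; the handwritten backward recomputes them instead. A small equivalence check against the naive expression, assuming HadaWeight is importable from this module:

```python
import torch

o, i, r = 4, 6, 2
w1u, w1d = torch.randn(o, r), torch.randn(r, i)
w2u, w2d = torch.randn(o, r), torch.randn(r, i)
scale = torch.tensor(0.5)

naive = ((w1u @ w1d) * (w2u @ w2d)) * scale          # the LoHa delta, spelled out
custom = HadaWeight.apply(w1u, w1d, w2u, w2d, scale)  # custom autograd Function
assert torch.allclose(naive, custom)
```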
@@ -3,7 +3,77 @@ from typing import Optional
 
 import torch
 import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose
+from .base import (
+    WeightAdapterBase,
+    WeightAdapterTrainBase,
+    weight_decompose,
+    factorization,
+)
+
+
+class LokrDiff(WeightAdapterTrainBase):
+    def __init__(self, weights):
+        super().__init__()
+        (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights
+        self.use_tucker = False
+        if lokr_w1_a is not None:
+            _, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1]
+            rank_a, _ = lokr_w1_b.shape[0], lokr_w1_b.shape[1]
+            self.lokr_w1_a = torch.nn.Parameter(lokr_w1_a)
+            self.lokr_w1_b = torch.nn.Parameter(lokr_w1_b)
+            self.w1_rebuild = True
+            self.ranka = rank_a
+
+        if lokr_w2_a is not None:
+            _, rank_b = lokr_w2_a.shape[0], lokr_w2_a.shape[1]
+            rank_b, _ = lokr_w2_b.shape[0], lokr_w2_b.shape[1]
+            self.lokr_w2_a = torch.nn.Parameter(lokr_w2_a)
+            self.lokr_w2_b = torch.nn.Parameter(lokr_w2_b)
+            if lokr_t2 is not None:
+                self.use_tucker = True
+                self.lokr_t2 = torch.nn.Parameter(lokr_t2)
+            self.w2_rebuild = True
+            self.rankb = rank_b
+
+        if lokr_w1 is not None:
+            self.lokr_w1 = torch.nn.Parameter(lokr_w1)
+            self.w1_rebuild = False
+
+        if lokr_w2 is not None:
+            self.lokr_w2 = torch.nn.Parameter(lokr_w2)
+            self.w2_rebuild = False
+
+        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
+
+    @property
+    def w1(self):
+        if self.w1_rebuild:
+            return (self.lokr_w1_a @ self.lokr_w1_b) * (self.alpha / self.ranka)
+        else:
+            return self.lokr_w1
+
+    @property
+    def w2(self):
+        if self.w2_rebuild:
+            if self.use_tucker:
+                w2 = torch.einsum(
+                    'i j k l, j r, i p -> p r k l',
+                    self.lokr_t2,
+                    self.lokr_w2_b,
+                    self.lokr_w2_a
+                )
+            else:
+                w2 = self.lokr_w2_a @ self.lokr_w2_b
+            return w2 * (self.alpha / self.rankb)
+        else:
+            return self.lokr_w2
+
+    def __call__(self, w):
+        diff = torch.kron(self.w1, self.w2)
+        return w + diff.reshape(w.shape).to(w)
+
+    def passive_memory_usage(self):
+        return sum(param.numel() * param.element_size() for param in self.parameters())
+
 
 class LoKrAdapter(WeightAdapterBase):
 
@@ -13,6 +83,20 @@ class LoKrAdapter(WeightAdapterBase):
         self.loaded_keys = loaded_keys
         self.weights = weights
 
+    @classmethod
+    def create_train(cls, weight, rank=1, alpha=1.0):
+        out_dim = weight.shape[0]
+        in_dim = weight.shape[1:].numel()
+        out1, out2 = factorization(out_dim, rank)
+        in1, in2 = factorization(in_dim, rank)
+        mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype)
+        mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype)
+        torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
+        torch.nn.init.constant_(mat1, 0.0)
+        return LokrDiff(
+            (mat1, mat2, alpha, None, None, None, None, None, None)
+        )
+
     @classmethod
     def load(
         cls,
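LoKr's parameter savings come from the Kronecker structure: torch.kron of an (out1, in1) factor with an (out2, in2) factor reconstructs an (out1*out2, in1*in2) delta, which is why create_train uses factorization() to pick the two factor shapes. A shape sketch with illustrative dimensions:

```python
import torch

w1 = torch.randn(4, 3)       # (out1, in1) factor
w2 = torch.randn(8, 16)      # (out2, in2) factor
delta = torch.kron(w1, w2)   # Kronecker product of the two factors
assert delta.shape == (4 * 8, 3 * 16)   # (32, 48), reshaped onto the full weight
```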
@@ -3,7 +3,58 @@ from typing import Optional
 
 import torch
 import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose
+from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization
+
+
+class OFTDiff(WeightAdapterTrainBase):
+    def __init__(self, weights):
+        super().__init__()
+        # Unpack weights tuple from OFTAdapter
+        blocks, rescale, alpha, _ = weights
+
+        # Create trainable parameters
+        self.oft_blocks = torch.nn.Parameter(blocks)
+        if rescale is not None:
+            self.rescale = torch.nn.Parameter(rescale)
+            self.rescaled = True
+        else:
+            self.rescaled = False
+        self.block_num, self.block_size, _ = blocks.shape
+        self.constraint = float(alpha)
+        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
+
+    def __call__(self, w):
+        org_dtype = w.dtype
+        I = torch.eye(self.block_size, device=self.oft_blocks.device)
+
+        ## generate r
+        # for Q = -Q^T
+        q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
+        normed_q = q
+        if self.constraint:
+            q_norm = torch.norm(q) + 1e-8
+            if q_norm > self.constraint:
+                normed_q = q * self.constraint / q_norm
+        # use float() to prevent unsupported type
+        r = (I + normed_q) @ (I - normed_q).float().inverse()
+
+        ## Apply chunked matmul on weight
+        _, *shape = w.shape
+        org_weight = w.to(dtype=r.dtype)
+        org_weight = org_weight.unflatten(0, (self.block_num, self.block_size))
+        # Init R=0, so add I on it to ensure the output of step0 is original model output
+        weight = torch.einsum(
+            "k n m, k n ... -> k m ...",
+            r,
+            org_weight,
+        ).flatten(0, 1)
+        if self.rescaled:
+            weight = self.rescale * weight
+        return weight.to(org_dtype)
+
+    def passive_memory_usage(self):
+        """Calculates memory usage of the trainable parameters."""
+        return sum(param.numel() * param.element_size() for param in self.parameters())
+
 
 class OFTAdapter(WeightAdapterBase):
 
@@ -13,6 +64,18 @@ class OFTAdapter(WeightAdapterBase):
         self.loaded_keys = loaded_keys
         self.weights = weights
 
+    @classmethod
+    def create_train(cls, weight, rank=1, alpha=1.0):
+        out_dim = weight.shape[0]
+        block_size, block_num = factorization(out_dim, rank)
+        block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype)
+        return OFTDiff(
+            (block, None, alpha, None)
+        )
+
+    def to_train(self):
+        return OFTDiff(self.weights)
+
     @classmethod
     def load(
         cls,
@@ -60,6 +123,8 @@ class OFTAdapter(WeightAdapterBase):
         blocks = v[0]
         rescale = v[1]
         alpha = v[2]
+        if alpha is None:
+            alpha = 0
         dora_scale = v[3]
 
         blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
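OFTDiff builds its rotation with a Cayley transform: for skew-symmetric Q, R = (I + Q)(I - Q)^-1 is orthogonal, so applying R to weight blocks preserves their norms. A standalone check of that property, with the block dimensions reduced to a single small matrix for clarity:

```python
import torch

n = 4
a = torch.randn(n, n)
q = a - a.T                                    # skew-symmetric: Q = -Q^T
eye = torch.eye(n)
r = (eye + q) @ torch.linalg.inv(eye - q)      # Cayley transform
assert torch.allclose(r @ r.T, eye, atol=1e-5)  # R is orthogonal
# Q = 0 gives R = I, which is why zero-initialized oft_blocks start as a no-op.
```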
@@ -656,9 +656,34 @@ class Accumulation(ComfyTypeIO):
         accum: list[Any]
     Type = AccumulationDict
 
+
 @comfytype(io_type="LOAD3D_CAMERA")
 class Load3DCamera(ComfyTypeIO):
-    Type = Any # TODO: figure out type for this; in code, only described as image['camera_info'], gotten from a LOAD_3D or LOAD_3D_ANIMATION type
+    class CameraInfo(TypedDict):
+        position: dict[str, float | int]
+        target: dict[str, float | int]
+        zoom: int
+        cameraType: str
+
+    Type = CameraInfo
+
+
+@comfytype(io_type="LOAD_3D")
+class Load3D(ComfyTypeIO):
+    """3D models are stored as a dictionary."""
+    class Model3DDict(TypedDict):
+        image: str
+        mask: str
+        normal: str
+        camera_info: Load3DCamera.CameraInfo
+        recording: NotRequired[str]
+
+    Type = Model3DDict
+
+
+@comfytype(io_type="LOAD_3D_ANIMATION")
+class Load3DAnimation(Load3D):
+    ...
+
+
 @comfytype(io_type="PHOTOMAKER")
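Replacing Type = Any with a TypedDict documents what the frontend actually sends for camera state. For illustration only, a value matching those keys (the concrete numbers and the "perspective" string are made up; the type only promises the field names and kinds shown in the hunk above):

```python
camera_info: Load3DCamera.CameraInfo = {
    "position": {"x": 0.0, "y": 1.5, "z": 3.0},
    "target": {"x": 0.0, "y": 0.0, "z": 0.0},
    "zoom": 1,
    "cameraType": "perspective",  # assumed value; the TypedDict only promises str
}
```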
@@ -475,11 +475,12 @@ class PreviewVideo(_UIOutput):
 
 
 class PreviewUI3D(_UIOutput):
-    def __init__(self, values: list[SavedResult | dict], **kwargs):
-        self.values = values
+    def __init__(self, model_file, camera_info, **kwargs):
+        self.model_file = model_file
+        self.camera_info = camera_info
 
     def as_dict(self):
-        return {"3d": self.values}
+        return {"result": [self.model_file, self.camera_info]}
 
 
 class PreviewText(_UIOutput):
@@ -20,7 +20,7 @@ import folder_paths
 import node_helpers
 from comfy.cli_args import args
 from comfy.comfy_types.node_typing import IO
-from comfy.weight_adapter import adapters
+from comfy.weight_adapter import adapters, adapter_maps
 
 
 def make_batch_extra_option_dict(d, indicies, full_size=None):
@@ -39,13 +39,13 @@ def make_batch_extra_option_dict(d, indicies, full_size=None):
 
 
 class TrainSampler(comfy.samplers.Sampler):
-    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
+    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
         self.loss_fn = loss_fn
         self.optimizer = optimizer
         self.loss_callback = loss_callback
         self.batch_size = batch_size
         self.total_steps = total_steps
+        self.grad_acc = grad_acc
         self.seed = seed
         self.training_dtype = training_dtype
 
@@ -92,6 +92,7 @@ class TrainSampler(comfy.samplers.Sampler):
             self.loss_callback(loss.item())
         pbar.set_postfix({"loss": f"{loss.item():.4f}"})
 
-        self.optimizer.step()
-        self.optimizer.zero_grad()
+        if (i+1) % self.grad_acc == 0:
+            self.optimizer.step()
+            self.optimizer.zero_grad()
         torch.cuda.empty_cache()
@@ -419,6 +420,16 @@ class TrainLoraNode:
                     "tooltip": "The batch size to use for training.",
                 },
             ),
+            "grad_accumulation_steps": (
+                IO.INT,
+                {
+                    "default": 1,
+                    "min": 1,
+                    "max": 1024,
+                    "step": 1,
+                    "tooltip": "The number of gradient accumulation steps to use for training.",
+                }
+            ),
             "steps": (
                 IO.INT,
                 {
@@ -478,6 +489,17 @@ class TrainLoraNode:
                 ["bf16", "fp32"],
                 {"default": "bf16", "tooltip": "The dtype to use for lora."},
             ),
+            "algorithm": (
+                list(adapter_maps.keys()),
+                {"default": list(adapter_maps.keys())[0], "tooltip": "The algorithm to use for training."},
+            ),
+            "gradient_checkpointing": (
+                IO.BOOLEAN,
+                {
+                    "default": True,
+                    "tooltip": "Use gradient checkpointing for training.",
+                }
+            ),
             "existing_lora": (
                 folder_paths.get_filename_list("loras") + ["[None]"],
                 {
@@ -501,6 +523,7 @@ class TrainLoraNode:
         positive,
         batch_size,
         steps,
+        grad_accumulation_steps,
         learning_rate,
         rank,
         optimizer,
@@ -508,6 +531,8 @@ class TrainLoraNode:
         seed,
         training_dtype,
         lora_dtype,
+        algorithm,
+        gradient_checkpointing,
         existing_lora,
     ):
         mp = model.clone()
@@ -558,10 +583,8 @@ class TrainLoraNode:
             if existing_adapter is not None:
                 break
         else:
-            # If no existing adapter found, use LoRA
-            # We will add algo option in the future
             existing_adapter = None
-            adapter_cls = adapters[0]
+            adapter_cls = adapter_maps[algorithm]
 
         if existing_adapter is not None:
             train_adapter = existing_adapter.to_train().to(lora_dtype)
@@ -615,6 +638,7 @@ class TrainLoraNode:
         criterion = torch.nn.SmoothL1Loss()
 
         # setup models
-        for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
-            patch(m)
+        if gradient_checkpointing:
+            for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
+                patch(m)
         mp.model.requires_grad_(False)
@@ -629,7 +653,8 @@ class TrainLoraNode:
             optimizer,
             loss_callback=loss_callback,
             batch_size=batch_size,
-            total_steps=steps,
+            grad_acc=grad_accumulation_steps,
+            total_steps=steps*grad_accumulation_steps,
             seed=seed,
             training_dtype=dtype
        )
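The shape of the gradient-accumulation change: backward() still runs every micro-step, but the optimizer only steps every grad_acc iterations, which is why total_steps is scaled to steps * grad_accumulation_steps so the number of optimizer updates stays at steps. A self-contained sketch of that loop structure, with a stand-in loss:

```python
import torch

grad_acc, steps = 4, 3
total_steps = steps * grad_acc

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)

updates = 0
for i in range(total_steps):
    loss = (param - 1.0).pow(2).mean()  # stand-in loss
    loss.backward()                     # gradients accumulate across micro-steps
    if (i + 1) % grad_acc == 0:
        optimizer.step()                # one update per grad_acc micro-batches
        optimizer.zero_grad()
        updates += 1
assert updates == steps                 # optimizer updates stay at `steps`
```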
@@ -19,14 +19,14 @@ class ConditioningStableAudio(io.ComfyNode):
             node_id="ConditioningStableAudio_V3",
             category="conditioning",
             inputs=[
-                io.Conditioning.Input(id="positive"),
-                io.Conditioning.Input(id="negative"),
-                io.Float.Input(id="seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1),
-                io.Float.Input(id="seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1),
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1),
+                io.Float.Input("seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1),
             ],
             outputs=[
-                io.Conditioning.Output(id="positive_out", display_name="positive"),
-                io.Conditioning.Output(id="negative_out", display_name="negative"),
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
             ],
         )
 
@@ -49,7 +49,7 @@ class EmptyLatentAudio(io.ComfyNode):
             node_id="EmptyLatentAudio_V3",
             category="latent/audio",
             inputs=[
-                io.Float.Input(id="seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
+                io.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
                 io.Int.Input(
                     id="batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
                 ),
@@ -200,8 +200,8 @@ class VAEDecodeAudio(io.ComfyNode):
             node_id="VAEDecodeAudio_V3",
             category="latent/audio",
             inputs=[
-                io.Latent.Input(id="samples"),
-                io.Vae.Input(id="vae"),
+                io.Latent.Input("samples"),
+                io.Vae.Input("vae"),
             ],
             outputs=[io.Audio.Output()],
         )
@@ -222,8 +222,8 @@ class VAEEncodeAudio(io.ComfyNode):
             node_id="VAEEncodeAudio_V3",
             category="latent/audio",
             inputs=[
-                io.Audio.Input(id="audio"),
-                io.Vae.Input(id="vae"),
+                io.Audio.Input("audio"),
+                io.Vae.Input("vae"),
             ],
             outputs=[io.Latent.Output()],
         )
@@ -13,7 +13,7 @@ class DifferentialDiffusion(io.ComfyNode):
             display_name="Differential Diffusion _V3",
             category="_for_testing",
             inputs=[
-                io.Model.Input(id="model"),
+                io.Model.Input("model"),
             ],
             outputs=[
                 io.Model.Output(),
@@ -32,10 +32,10 @@ class CLIPTextEncodeFlux(io.ComfyNode):
             node_id="CLIPTextEncodeFlux_V3",
             category="advanced/conditioning/flux",
             inputs=[
-                io.Clip.Input(id="clip"),
-                io.String.Input(id="clip_l", multiline=True, dynamic_prompts=True),
-                io.String.Input(id="t5xxl", multiline=True, dynamic_prompts=True),
-                io.Float.Input(id="guidance", default=3.5, min=0.0, max=100.0, step=0.1),
+                io.Clip.Input("clip"),
+                io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
+                io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
+                io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
             ],
             outputs=[
                 io.Conditioning.Output(),
@@ -58,7 +58,7 @@ class FluxDisableGuidance(io.ComfyNode):
             category="advanced/conditioning/flux",
             description="This node completely disables the guidance embed on Flux and Flux like models",
             inputs=[
-                io.Conditioning.Input(id="conditioning"),
+                io.Conditioning.Input("conditioning"),
             ],
             outputs=[
                 io.Conditioning.Output(),
@@ -78,8 +78,8 @@ class FluxGuidance(io.ComfyNode):
             node_id="FluxGuidance_V3",
             category="advanced/conditioning/flux",
             inputs=[
-                io.Conditioning.Input(id="conditioning"),
-                io.Float.Input(id="guidance", default=3.5, min=0.0, max=100.0, step=0.1),
+                io.Conditioning.Input("conditioning"),
+                io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
             ],
             outputs=[
                 io.Conditioning.Output(),
@@ -100,7 +100,7 @@ class FluxKontextImageScale(io.ComfyNode):
             category="advanced/conditioning/flux",
             description="This node resizes the image to one that is more optimal for flux kontext.",
             inputs=[
-                io.Image.Input(id="image"),
+                io.Image.Input("image"),
             ],
             outputs=[
                 io.Image.Output(),
@@ -35,11 +35,11 @@ class FreeU(io.ComfyNode):
             node_id="FreeU_V3",
             category="model_patches/unet",
             inputs=[
-                io.Model.Input(id="model"),
-                io.Float.Input(id="b1", default=1.1, min=0.0, max=10.0, step=0.01),
-                io.Float.Input(id="b2", default=1.2, min=0.0, max=10.0, step=0.01),
-                io.Float.Input(id="s1", default=0.9, min=0.0, max=10.0, step=0.01),
-                io.Float.Input(id="s2", default=0.2, min=0.0, max=10.0, step=0.01),
+                io.Model.Input("model"),
+                io.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01),
+                io.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01),
+                io.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
+                io.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
             ],
             outputs=[
                 io.Model.Output(),
@@ -80,11 +80,11 @@ class FreeU_V2(io.ComfyNode):
             node_id="FreeU_V2_V3",
             category="model_patches/unet",
             inputs=[
-                io.Model.Input(id="model"),
-                io.Float.Input(id="b1", default=1.3, min=0.0, max=10.0, step=0.01),
-                io.Float.Input(id="b2", default=1.4, min=0.0, max=10.0, step=0.01),
-                io.Float.Input(id="s1", default=0.9, min=0.0, max=10.0, step=0.01),
-                io.Float.Input(id="s2", default=0.2, min=0.0, max=10.0, step=0.01),
+                io.Model.Input("model"),
+                io.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01),
+                io.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01),
+                io.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
+                io.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
             ],
             outputs=[
                 io.Model.Output(),
@@ -65,12 +65,12 @@ class FreSca(io.ComfyNode):
             category="_for_testing",
             description="Applies frequency-dependent scaling to the guidance",
             inputs=[
-                io.Model.Input(id="model"),
-                io.Float.Input(id="scale_low", default=1.0, min=0, max=10, step=0.01,
+                io.Model.Input("model"),
+                io.Float.Input("scale_low", default=1.0, min=0, max=10, step=0.01,
                                tooltip="Scaling factor for low-frequency components"),
-                io.Float.Input(id="scale_high", default=1.25, min=0, max=10, step=0.01,
+                io.Float.Input("scale_high", default=1.25, min=0, max=10, step=0.01,
                                tooltip="Scaling factor for high-frequency components"),
-                io.Int.Input(id="freq_cutoff", default=20, min=1, max=10000, step=1,
+                io.Int.Input("freq_cutoff", default=20, min=1, max=10000, step=1,
                              tooltip="Number of frequency indices around center to consider as low-frequency"),
             ],
             outputs=[
@@ -343,9 +343,9 @@ class GITSScheduler(io.ComfyNode):
             node_id="GITSScheduler_V3",
             category="sampling/custom_sampling/schedulers",
             inputs=[
-                io.Float.Input(id="coeff", default=1.20, min=0.80, max=1.50, step=0.05),
-                io.Int.Input(id="steps", default=10, min=2, max=1000),
-                io.Float.Input(id="denoise", default=1.0, min=0.0, max=1.0, step=0.01),
+                io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05),
+                io.Int.Input("steps", default=10, min=2, max=1000),
+                io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
             ],
             outputs=[
                 io.Sigmas.Output(),
comfy_extras/v3/nodes_hunyuan.py (new file, 167 lines)
@@ -0,0 +1,167 @@
+from __future__ import annotations
+
+import torch
+
+import comfy.model_management
+import comfy.utils
+import node_helpers
+import nodes
+from comfy_api.v3 import io
+
+PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
+    "<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
+    "1. The main content and theme of the video."
+    "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
+    "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
+    "4. background environment, light, style and atmosphere."
+    "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
+    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+    "<|start_header_id|>assistant<|end_header_id|>\n\n"
+)
+
+class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="CLIPTextEncodeHunyuanDiT_V3",
+            category="advanced/conditioning",
+            inputs=[
+                io.Clip.Input("clip"),
+                io.String.Input("bert", multiline=True, dynamic_prompts=True),
+                io.String.Input("mt5xl", multiline=True, dynamic_prompts=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, clip, bert, mt5xl):
+        tokens = clip.tokenize(bert)
+        tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
+
+        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
+
+
+class EmptyHunyuanLatentVideo(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="EmptyHunyuanLatentVideo_V3",
+            category="latent/video",
+            inputs=[
+                io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=25, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, width, height, length, batch_size):
+        latent = torch.zeros(
+            [batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
+            device=comfy.model_management.intermediate_device(),
+        )
+        return io.NodeOutput({"samples":latent})
+
+
+class HunyuanImageToVideo(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="HunyuanImageToVideo_V3",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=53, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.Combo.Input("guidance_type", options=["v1 (concat)", "v2 (replace)", "custom"]),
+                io.Image.Input("start_image", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
+        latent = torch.zeros(
+            [batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
+            device=comfy.model_management.intermediate_device(),
+        )
+        out_latent = {}
+
+        if start_image is not None:
+            start_image = comfy.utils.common_upscale(
+                start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center"
+            ).movedim(1, -1)
+
+            concat_latent_image = vae.encode(start_image)
+            mask = torch.ones(
+                (1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]),
+                device=start_image.device,
+                dtype=start_image.dtype,
+            )
+            mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
+
+            if guidance_type == "v1 (concat)":
+                cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
+            elif guidance_type == "v2 (replace)":
+                cond = {'guiding_frame_index': 0}
+                latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
+                out_latent["noise_mask"] = mask
+            elif guidance_type == "custom":
+                cond = {"ref_latent": concat_latent_image}
+
+            positive = node_helpers.conditioning_set_values(positive, cond)
+
+        out_latent["samples"] = latent
+        return io.NodeOutput(positive, out_latent)
+
+
+class TextEncodeHunyuanVideo_ImageToVideo(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="TextEncodeHunyuanVideo_ImageToVideo_V3",
+            category="advanced/conditioning",
+            inputs=[
+                io.Clip.Input("clip"),
+                io.ClipVisionOutput.Input("clip_vision_output"),
+                io.String.Input("prompt", multiline=True, dynamic_prompts=True),
+                io.Int.Input(
+                    "image_interleave",
+                    default=2,
+                    min=1,
+                    max=512,
+                    tooltip="How much the image influences things vs the text prompt. Higher number means more influence from the text prompt.",
+                ),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, clip, clip_vision_output, prompt, image_interleave):
+        tokens = clip.tokenize(
+            prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V,
+            image_embeds=clip_vision_output.mm_projected,
+            image_interleave=image_interleave,
+        )
+        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
+
+
+NODES_LIST = [
+    CLIPTextEncodeHunyuanDiT,
+    EmptyHunyuanLatentVideo,
+    HunyuanImageToVideo,
+    TextEncodeHunyuanVideo_ImageToVideo,
+]
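The latent allocations in EmptyHunyuanLatentVideo and HunyuanImageToVideo encode the model's compression scheme: 8x spatially, 4x temporally, with one extra latent frame anchoring the sequence. Worked through for the node defaults:

```python
width, height, length = 848, 480, 25
t = ((length - 1) // 4) + 1                    # 4x temporal compression, +1 anchor frame
shape = (1, 16, t, height // 8, width // 8)    # 16 latent channels, 8x spatial compression
assert shape == (1, 16, 7, 60, 106)
```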
136
comfy_extras/v3/nodes_hypernetwork.py
Normal file
136
comfy_extras/v3/nodes_hypernetwork.py
Normal file
@ -0,0 +1,136 @@
from __future__ import annotations

import logging

import torch

import comfy.utils
import folder_paths
from comfy_api.v3 import io


def load_hypernetwork_patch(path, strength):
    sd = comfy.utils.load_torch_file(path, safe_load=True)
    activation_func = sd.get('activation_func', 'linear')
    is_layer_norm = sd.get('is_layer_norm', False)
    use_dropout = sd.get('use_dropout', False)
    activate_output = sd.get('activate_output', False)
    last_layer_dropout = sd.get('last_layer_dropout', False)

    valid_activation = {
        "linear": torch.nn.Identity,
        "relu": torch.nn.ReLU,
        "leakyrelu": torch.nn.LeakyReLU,
        "elu": torch.nn.ELU,
        "swish": torch.nn.Hardswish,
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
        "softsign": torch.nn.Softsign,
        "mish": torch.nn.Mish,
    }

    if activation_func not in valid_activation:
        logging.error(
            "Unsupported Hypernetwork format, if you report it I might implement it. {} {} {} {} {} {}".format(
                path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout
            )
        )
        return None

    out = {}

    for d in sd:
        try:
            dim = int(d)
        except Exception:
            continue

        output = []
        for index in [0, 1]:
            attn_weights = sd[dim][index]
            keys = attn_weights.keys()

            linears = filter(lambda a: a.endswith(".weight"), keys)
            linears = list(map(lambda a: a[:-len(".weight")], linears))
            layers = []

            i = 0
            while i < len(linears):
                lin_name = linears[i]
                last_layer = (i == (len(linears) - 1))
                penultimate_layer = (i == (len(linears) - 2))

                lin_weight = attn_weights['{}.weight'.format(lin_name)]
                lin_bias = attn_weights['{}.bias'.format(lin_name)]
                layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
                layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
                layers.append(layer)
                if activation_func != "linear":
                    if (not last_layer) or (activate_output):
                        layers.append(valid_activation[activation_func]())
                if is_layer_norm:
                    i += 1
                    ln_name = linears[i]
                    ln_weight = attn_weights['{}.weight'.format(ln_name)]
                    ln_bias = attn_weights['{}.bias'.format(ln_name)]
                    ln = torch.nn.LayerNorm(ln_weight.shape[0])
                    ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
                    layers.append(ln)
                if use_dropout:
                    if (not last_layer) and (not penultimate_layer or last_layer_dropout):
                        layers.append(torch.nn.Dropout(p=0.3))
                i += 1

            output.append(torch.nn.Sequential(*layers))
        out[dim] = torch.nn.ModuleList(output)

    class hypernetwork_patch:
        def __init__(self, hypernet, strength):
            self.hypernet = hypernet
            self.strength = strength

        def __call__(self, q, k, v, extra_options):
            dim = k.shape[-1]
            if dim in self.hypernet:
                hn = self.hypernet[dim]
                k = k + hn[0](k) * self.strength
                v = v + hn[1](v) * self.strength

            return q, k, v

        def to(self, device):
            for d in self.hypernet.keys():
                self.hypernet[d] = self.hypernet[d].to(device)
            return self

    return hypernetwork_patch(out, strength)


class HypernetworkLoader(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="HypernetworkLoader_V3",
            category="loaders",
            inputs=[
                io.Model.Input("model"),
                io.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")),
                io.Float.Input("strength", default=1.0, min=-10.0, max=10.0, step=0.01),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, hypernetwork_name, strength):
        hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
        model_hypernetwork = model.clone()
        patch = load_hypernetwork_patch(hypernetwork_path, strength)
        if patch is not None:
            model_hypernetwork.set_model_attn1_patch(patch)
            model_hypernetwork.set_model_attn2_patch(patch)
        return io.NodeOutput(model_hypernetwork)


NODES_LIST = [
    HypernetworkLoader,
]
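For reference, a minimal standalone sketch (not part of the diff) of the update rule applied by hypernetwork_patch.__call__ above: the attention keys and values are shifted by the two loaded per-dimension networks, scaled by strength, while q passes through untouched. The dimension and the two-layer MLPs here are illustrative assumptions, not actual loaded weights.

import torch

dim = 768  # illustrative attention dim; the real patch keys on k.shape[-1]
hn_k = torch.nn.Sequential(torch.nn.Linear(dim, dim), torch.nn.Linear(dim, dim))
hn_v = torch.nn.Sequential(torch.nn.Linear(dim, dim), torch.nn.Linear(dim, dim))
strength = 0.8

q = torch.randn(1, 77, dim)
k = torch.randn(1, 77, dim)
v = torch.randn(1, 77, dim)

# Same update as hypernetwork_patch.__call__:
k = k + hn_k(k) * strength
v = v + hn_v(v) * strength
print(k.shape, v.shape)  # torch.Size([1, 77, 768]) torch.Size([1, 77, 768])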
comfy_extras/v3/nodes_hypertile.py (new file, 95 lines)
@@ -0,0 +1,95 @@
"""Taken from: https://github.com/tfernd/HyperTile/"""

from __future__ import annotations

import math

from einops import rearrange
from torch import randint

from comfy_api.v3 import io


def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
    min_value = min(min_value, value)

    # All big divisors of value (inclusive)
    divisors = [i for i in range(min_value, value + 1) if value % i == 0]

    ns = [value // i for i in divisors[:max_options]]  # has at least 1 element

    if len(ns) - 1 > 0:
        idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()
    else:
        idx = 0

    return ns[idx]


class HyperTile(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="HyperTile_V3",
            category="model_patches/unet",
            inputs=[
                io.Model.Input("model"),
                io.Int.Input("tile_size", default=256, min=1, max=2048),
                io.Int.Input("swap_size", default=2, min=1, max=128),
                io.Int.Input("max_depth", default=0, min=0, max=10),
                io.Boolean.Input("scale_depth", default=False),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, tile_size, swap_size, max_depth, scale_depth):
        latent_tile_size = max(32, tile_size) // 8
        temp = None

        def hypertile_in(q, k, v, extra_options):
            nonlocal temp
            model_chans = q.shape[-2]
            orig_shape = extra_options['original_shape']
            apply_to = []
            for i in range(max_depth + 1):
                apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i)))

            if model_chans in apply_to:
                shape = extra_options["original_shape"]
                aspect_ratio = shape[-1] / shape[-2]

                hw = q.size(1)
                h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))

                factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1
                nh = random_divisor(h, latent_tile_size * factor, swap_size)
                nw = random_divisor(w, latent_tile_size * factor, swap_size)

                if nh * nw > 1:
                    q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
                    temp = (nh, nw, h, w)
                return q, k, v

            return q, k, v

        def hypertile_out(out, extra_options):
            nonlocal temp
            if temp is not None:
                nh, nw, h, w = temp
                temp = None
                out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
                out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
            return out

        m = model.clone()
        m.set_model_attn1_patch(hypertile_in)
        m.set_model_attn1_output_patch(hypertile_out)
        return io.NodeOutput(m)


NODES_LIST = [
    HyperTile,
]
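A standalone sketch of the rearrange round-trip HyperTile relies on: tokens are regrouped into nh*nw tiles before attention (hypertile_in) and restored afterwards (hypertile_out), so the patch is shape-preserving. The grid and tile counts below are illustrative assumptions.

import torch
from einops import rearrange

b, h, w, c = 2, 32, 48, 64
nh, nw = 2, 3  # must divide h and w, as random_divisor guarantees
x = torch.randn(b, h * w, c)

tiled = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c",
                  h=h // nh, w=w // nw, nh=nh, nw=nw)
assert tiled.shape == (b * nh * nw, (h // nh) * (w // nw), c)

# Inverse mapping used in hypertile_out:
restored = rearrange(tiled, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
restored = rearrange(restored, "b nh nw (h w) c -> b (nh h nw w) c",
                     h=h // nh, w=w // nw)
assert torch.equal(x, restored)  # pure permutation, exact round-trip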
comfy_extras/v3/nodes_ip2p.py (new file, 56 lines)
@@ -0,0 +1,56 @@
from __future__ import annotations

import torch

from comfy_api.v3 import io


class InstructPixToPixConditioning(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="InstructPixToPixConditioning_V3",
            category="conditioning/instructpix2pix",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae"),
                io.Image.Input("pixels"),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, positive, negative, pixels, vae):
        x = (pixels.shape[1] // 8) * 8
        y = (pixels.shape[2] // 8) * 8

        if pixels.shape[1] != x or pixels.shape[2] != y:
            x_offset = (pixels.shape[1] % 8) // 2
            y_offset = (pixels.shape[2] % 8) // 2
            pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]

        concat_latent = vae.encode(pixels)

        out_latent = {}
        out_latent["samples"] = torch.zeros_like(concat_latent)

        out = []
        for conditioning in [positive, negative]:
            c = []
            for t in conditioning:
                d = t[1].copy()
                d["concat_latent_image"] = concat_latent
                n = [t[0], d]
                c.append(n)
            out.append(c)
        return io.NodeOutput(out[0], out[1], out_latent)


NODES_LIST = [
    InstructPixToPixConditioning,
]
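A quick standalone check of the center-crop arithmetic in execute above: spatial dims are trimmed to multiples of 8 (the VAE downscale factor), splitting the excess symmetrically. The input size is an arbitrary example.

import torch

pixels = torch.randn(1, 517, 389, 3)   # [B, H, W, C], sizes chosen arbitrarily
x = (pixels.shape[1] // 8) * 8         # 512
y = (pixels.shape[2] // 8) * 8         # 384
x_offset = (pixels.shape[1] % 8) // 2  # 2: half the excess on each side
y_offset = (pixels.shape[2] % 8) // 2  # 2
pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
print(pixels.shape)  # torch.Size([1, 512, 384, 3])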
comfy_extras/v3/nodes_latent.py
@@ -24,8 +24,8 @@ class LatentAdd(io.ComfyNode):
             node_id="LatentAdd_V3",
             category="latent/advanced",
             inputs=[
-                io.Latent.Input(id="samples1"),
-                io.Latent.Input(id="samples2"),
+                io.Latent.Input("samples1"),
+                io.Latent.Input("samples2"),
             ],
             outputs=[
                 io.Latent.Output(),
@@ -52,8 +52,8 @@ class LatentApplyOperation(io.ComfyNode):
             category="latent/advanced/operations",
             is_experimental=True,
             inputs=[
-                io.Latent.Input(id="samples"),
-                io.LatentOperation.Input(id="operation"),
+                io.Latent.Input("samples"),
+                io.LatentOperation.Input("operation"),
             ],
             outputs=[
                 io.Latent.Output(),
@@ -77,8 +77,8 @@ class LatentApplyOperationCFG(io.ComfyNode):
             category="latent/advanced/operations",
             is_experimental=True,
             inputs=[
-                io.Model.Input(id="model"),
-                io.LatentOperation.Input(id="operation"),
+                io.Model.Input("model"),
+                io.LatentOperation.Input("operation"),
             ],
             outputs=[
                 io.Model.Output(),
@@ -108,8 +108,8 @@ class LatentBatch(io.ComfyNode):
             node_id="LatentBatch_V3",
             category="latent/batch",
             inputs=[
-                io.Latent.Input(id="samples1"),
-                io.Latent.Input(id="samples2"),
+                io.Latent.Input("samples1"),
+                io.Latent.Input("samples2"),
             ],
             outputs=[
                 io.Latent.Output(),
@@ -137,8 +137,8 @@ class LatentBatchSeedBehavior(io.ComfyNode):
             node_id="LatentBatchSeedBehavior_V3",
             category="latent/advanced",
             inputs=[
-                io.Latent.Input(id="samples"),
-                io.Combo.Input(id="seed_behavior", options=["random", "fixed"], default="fixed"),
+                io.Latent.Input("samples"),
+                io.Combo.Input("seed_behavior", options=["random", "fixed"], default="fixed"),
             ],
             outputs=[
                 io.Latent.Output(),
@@ -166,9 +166,9 @@ class LatentInterpolate(io.ComfyNode):
             node_id="LatentInterpolate_V3",
             category="latent/advanced",
             inputs=[
-                io.Latent.Input(id="samples1"),
-                io.Latent.Input(id="samples2"),
-                io.Float.Input(id="ratio", default=1.0, min=0.0, max=1.0, step=0.01),
+                io.Latent.Input("samples1"),
+                io.Latent.Input("samples2"),
+                io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
             ],
             outputs=[
                 io.Latent.Output(),
@@ -205,8 +205,8 @@ class LatentMultiply(io.ComfyNode):
             node_id="LatentMultiply_V3",
             category="latent/advanced",
             inputs=[
-                io.Latent.Input(id="samples"),
-                io.Float.Input(id="multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
+                io.Latent.Input("samples"),
+                io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
             ],
             outputs=[
                 io.Latent.Output(),
@@ -230,9 +230,9 @@ class LatentOperationSharpen(io.ComfyNode):
             category="latent/advanced/operations",
             is_experimental=True,
             inputs=[
-                io.Int.Input(id="sharpen_radius", default=9, min=1, max=31, step=1),
-                io.Float.Input(id="sigma", default=1.0, min=0.1, max=10.0, step=0.1),
-                io.Float.Input(id="alpha", default=0.1, min=0.0, max=5.0, step=0.01),
+                io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1),
+                io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1),
+                io.Float.Input("alpha", default=0.1, min=0.0, max=5.0, step=0.01),
             ],
             outputs=[
                 io.LatentOperation.Output(),
@@ -272,7 +272,7 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
             category="latent/advanced/operations",
             is_experimental=True,
             inputs=[
-                io.Float.Input(id="multiplier", default=1.0, min=0.0, max=100.0, step=0.01),
+                io.Float.Input("multiplier", default=1.0, min=0.0, max=100.0, step=0.01),
             ],
             outputs=[
                 io.LatentOperation.Output(),
@@ -306,8 +306,8 @@ class LatentSubtract(io.ComfyNode):
             node_id="LatentSubtract_V3",
             category="latent/advanced",
             inputs=[
-                io.Latent.Input(id="samples1"),
-                io.Latent.Input(id="samples2"),
+                io.Latent.Input("samples1"),
+                io.Latent.Input("samples2"),
             ],
             outputs=[
                 io.Latent.Output(),
comfy_extras/v3/nodes_load_3d.py (new file, 180 lines)
@@ -0,0 +1,180 @@
from __future__ import annotations

import os
from pathlib import Path

import folder_paths
import nodes
from comfy_api.input_impl import VideoFromFile
from comfy_api.v3 import io, ui


def normalize_path(path):
    return path.replace("\\", "/")


class Load3D(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        input_dir = os.path.join(folder_paths.get_input_directory(), "3d")

        os.makedirs(input_dir, exist_ok=True)

        input_path = Path(input_dir)
        base_path = Path(folder_paths.get_input_directory())

        files = [
            normalize_path(str(file_path.relative_to(base_path)))
            for file_path in input_path.rglob("*")
            if file_path.suffix.lower() in {".gltf", ".glb", ".obj", ".fbx", ".stl"}
        ]

        return io.Schema(
            node_id="Load3D_V3",
            display_name="Load 3D _V3",
            category="3d",
            is_experimental=True,
            inputs=[
                io.Combo.Input("model_file", options=sorted(files), upload=io.UploadType.model),
                io.Load3D.Input("image"),
                io.Int.Input("width", default=1024, min=1, max=4096, step=1),
                io.Int.Input("height", default=1024, min=1, max=4096, step=1),
            ],
            outputs=[
                io.Image.Output(display_name="image"),
                io.Mask.Output(display_name="mask"),
                io.String.Output(display_name="mesh_path"),
                io.Image.Output(display_name="normal"),
                io.Image.Output(display_name="lineart"),
                io.Load3DCamera.Output(display_name="camera_info"),
                io.Video.Output(display_name="recording_video"),
            ],
        )

    @classmethod
    def execute(cls, model_file, image, **kwargs):
        image_path = folder_paths.get_annotated_filepath(image["image"])
        mask_path = folder_paths.get_annotated_filepath(image["mask"])
        normal_path = folder_paths.get_annotated_filepath(image["normal"])
        lineart_path = folder_paths.get_annotated_filepath(image["lineart"])

        load_image_node = nodes.LoadImage()
        output_image, ignore_mask = load_image_node.load_image(image=image_path)
        ignore_image, output_mask = load_image_node.load_image(image=mask_path)
        normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
        lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)

        video = None
        if image["recording"] != "":
            recording_video_path = folder_paths.get_annotated_filepath(image["recording"])
            video = VideoFromFile(recording_video_path)

        return io.NodeOutput(
            output_image, output_mask, model_file, normal_image, lineart_image, image["camera_info"], video
        )


class Load3DAnimation(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        input_dir = os.path.join(folder_paths.get_input_directory(), "3d")

        os.makedirs(input_dir, exist_ok=True)

        input_path = Path(input_dir)
        base_path = Path(folder_paths.get_input_directory())

        files = [
            normalize_path(str(file_path.relative_to(base_path)))
            for file_path in input_path.rglob("*")
            if file_path.suffix.lower() in {".gltf", ".glb", ".fbx"}
        ]

        return io.Schema(
            node_id="Load3DAnimation_V3",
            display_name="Load 3D - Animation _V3",
            category="3d",
            is_experimental=True,
            inputs=[
                io.Combo.Input("model_file", options=sorted(files), upload=io.UploadType.model),
                io.Load3DAnimation.Input("image"),
                io.Int.Input("width", default=1024, min=1, max=4096, step=1),
                io.Int.Input("height", default=1024, min=1, max=4096, step=1),
            ],
            outputs=[
                io.Image.Output(display_name="image"),
                io.Mask.Output(display_name="mask"),
                io.String.Output(display_name="mesh_path"),
                io.Image.Output(display_name="normal"),
                io.Load3DCamera.Output(display_name="camera_info"),
                io.Video.Output(display_name="recording_video"),
            ],
        )

    @classmethod
    def execute(cls, model_file, image, **kwargs):
        image_path = folder_paths.get_annotated_filepath(image["image"])
        mask_path = folder_paths.get_annotated_filepath(image["mask"])
        normal_path = folder_paths.get_annotated_filepath(image["normal"])

        load_image_node = nodes.LoadImage()
        output_image, ignore_mask = load_image_node.load_image(image=image_path)
        ignore_image, output_mask = load_image_node.load_image(image=mask_path)
        normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)

        video = None
        if image['recording'] != "":
            recording_video_path = folder_paths.get_annotated_filepath(image["recording"])
            video = VideoFromFile(recording_video_path)

        return io.NodeOutput(output_image, output_mask, model_file, normal_image, image["camera_info"], video)


class Preview3D(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="Preview3D_V3",  # frontend expects "Preview3D" to work
            display_name="Preview 3D _V3",
            category="3d",
            is_experimental=True,
            is_output_node=True,
            inputs=[
                io.String.Input("model_file", default="", multiline=False),
                io.Load3DCamera.Input("camera_info", optional=True),
            ],
            outputs=[],
        )

    @classmethod
    def execute(cls, model_file, camera_info=None):
        return io.NodeOutput(ui=ui.PreviewUI3D(model_file, camera_info, cls=cls))


class Preview3DAnimation(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="Preview3DAnimation_V3",  # frontend expects "Preview3DAnimation" to work
            display_name="Preview 3D - Animation _V3",
            category="3d",
            is_experimental=True,
            is_output_node=True,
            inputs=[
                io.String.Input("model_file", default="", multiline=False),
                io.Load3DCamera.Input("camera_info", optional=True),
            ],
            outputs=[],
        )

    @classmethod
    def execute(cls, model_file, camera_info=None):
        return io.NodeOutput(ui=ui.PreviewUI3D(model_file, camera_info, cls=cls))


NODES_LIST = [
    Load3D,
    Load3DAnimation,
    Preview3D,
    Preview3DAnimation,
]
comfy_extras/v3/nodes_lora_extract.py (new file, 138 lines)
@@ -0,0 +1,138 @@
from __future__ import annotations

import logging
import os
from enum import Enum

import torch

import comfy.model_management
import comfy.utils
import folder_paths
from comfy_api.v3 import io

CLAMP_QUANTILE = 0.99


def extract_lora(diff, rank):
    conv2d = (len(diff.shape) == 4)
    kernel_size = None if not conv2d else diff.size()[2:4]
    conv2d_3x3 = conv2d and kernel_size != (1, 1)
    out_dim, in_dim = diff.size()[0:2]
    rank = min(rank, in_dim, out_dim)

    if conv2d:
        if conv2d_3x3:
            diff = diff.flatten(start_dim=1)
        else:
            diff = diff.squeeze()

    U, S, Vh = torch.linalg.svd(diff.float())
    U = U[:, :rank]
    S = S[:rank]
    U = U @ torch.diag(S)
    Vh = Vh[:rank, :]

    dist = torch.cat([U.flatten(), Vh.flatten()])
    hi_val = torch.quantile(dist, CLAMP_QUANTILE)
    low_val = -hi_val

    U = U.clamp(low_val, hi_val)
    Vh = Vh.clamp(low_val, hi_val)
    if conv2d:
        U = U.reshape(out_dim, rank, 1, 1)
        Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
    return (U, Vh)


class LORAType(Enum):
    STANDARD = 0
    FULL_DIFF = 1


LORA_TYPES = {
    "standard": LORAType.STANDARD,
    "full_diff": LORAType.FULL_DIFF,
}


def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
    comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
    sd = model_diff.model_state_dict(filter_prefix=prefix_model)

    for k in sd:
        if k.endswith(".weight"):
            weight_diff = sd[k]
            if lora_type == LORAType.STANDARD:
                if weight_diff.ndim < 2:
                    if bias_diff:
                        output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
                    continue
                try:
                    out = extract_lora(weight_diff, rank)
                    output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().half().cpu()
                    output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().half().cpu()
                except Exception:
                    logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))
            elif lora_type == LORAType.FULL_DIFF:
                output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()

        elif bias_diff and k.endswith(".bias"):
            output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
    return output_sd


class LoraSave(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoraSave_V3",
            display_name="Extract and Save Lora _V3",
            category="_for_testing",
            is_output_node=True,
            inputs=[
                io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
                io.Int.Input("rank", default=8, min=1, max=4096, step=1),
                io.Combo.Input("lora_type", options=list(LORA_TYPES.keys())),
                io.Boolean.Input("bias_diff", default=True),
                io.Model.Input(
                    id="model_diff", optional=True, tooltip="The ModelSubtract output to be converted to a lora."
                ),
                io.Clip.Input(
                    id="text_encoder_diff", optional=True, tooltip="The CLIPSubtract output to be converted to a lora."
                ),
            ],
            outputs=[],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
        if model_diff is None and text_encoder_diff is None:
            return io.NodeOutput()

        lora_type = LORA_TYPES.get(lora_type)
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
            filename_prefix, folder_paths.get_output_directory()
        )

        output_sd = {}
        if model_diff is not None:
            output_sd = calc_lora_model(
                model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, bias_diff=bias_diff
            )
        if text_encoder_diff is not None:
            output_sd = calc_lora_model(
                text_encoder_diff.patcher, rank, "", "text_encoders.", output_sd, lora_type, bias_diff=bias_diff
            )

        output_checkpoint = f"{filename}_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

        comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
        return io.NodeOutput()


NODES_LIST = [
    LoraSave,
]
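A standalone sketch of the truncated-SVD step inside extract_lora above (omitting the quantile clamp): keeping the top `rank` singular directions yields a low-rank factorization U @ Vh that approximates the weight difference. The matrix size here is an illustrative assumption.

import torch

diff = torch.randn(64, 32)  # illustrative weight difference
rank = 8
U, S, Vh = torch.linalg.svd(diff.float())
U = U[:, :rank] @ torch.diag(S[:rank])  # fold singular values into U
Vh = Vh[:rank, :]

approx = U @ Vh  # rank-8 approximation of diff
print(U.shape, Vh.shape, torch.linalg.matrix_rank(approx).item())
# torch.Size([64, 8]) torch.Size([8, 32]) 8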
comfy_extras/v3/nodes_lotus.py (new file, 34 lines)
File diff suppressed because one or more lines are too long
comfy_extras/v3/nodes_ltxv.py
@@ -93,10 +93,10 @@ class EmptyLTXVLatentVideo(io.ComfyNode):
             node_id="EmptyLTXVLatentVideo_V3",
             category="latent/video/ltxv",
             inputs=[
-                io.Int.Input(id="width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
-                io.Int.Input(id="height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
-                io.Int.Input(id="length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=8),
-                io.Int.Input(id="batch_size", default=1, min=1, max=4096),
+                io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
+                io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
+                io.Int.Input("length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=8),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
             ],
             outputs=[
                 io.Latent.Output(),
@@ -122,10 +122,10 @@ class LTXVAddGuide(io.ComfyNode):
             node_id="LTXVAddGuide_V3",
             category="conditioning/video_models",
             inputs=[
-                io.Conditioning.Input(id="positive"),
-                io.Conditioning.Input(id="negative"),
-                io.Vae.Input(id="vae"),
-                io.Latent.Input(id="latent"),
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Latent.Input("latent"),
                 io.Image.Input(
                     id="image",
                     tooltip="Image or video to condition the latent video on. Must be 8*n + 1 frames. "
@@ -141,12 +141,12 @@ class LTXVAddGuide(io.ComfyNode):
                     "For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
                     "down to the nearest multiple of 8. Negative values are counted from the end of the video.",
                 ),
-                io.Float.Input(id="strength", default=1.0, min=0.0, max=1.0, step=0.01),
+                io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
             ],
             outputs=[
-                io.Conditioning.Output(id="positive_out", display_name="positive"),
-                io.Conditioning.Output(id="negative_out", display_name="negative"),
-                io.Latent.Output(id="latent_out", display_name="latent"),
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
             ],
         )

@@ -282,13 +282,13 @@ class LTXVConditioning(io.ComfyNode):
             node_id="LTXVConditioning_V3",
             category="conditioning/video_models",
             inputs=[
-                io.Conditioning.Input(id="positive"),
-                io.Conditioning.Input(id="negative"),
-                io.Float.Input(id="frame_rate", default=25.0, min=0.0, max=1000.0, step=0.01),
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Float.Input("frame_rate", default=25.0, min=0.0, max=1000.0, step=0.01),
             ],
             outputs=[
-                io.Conditioning.Output(id="positive_out", display_name="positive"),
-                io.Conditioning.Output(id="negative_out", display_name="negative"),
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
             ],
         )

@@ -306,14 +306,14 @@ class LTXVCropGuides(io.ComfyNode):
             node_id="LTXVCropGuides_V3",
             category="conditioning/video_models",
             inputs=[
-                io.Conditioning.Input(id="positive"),
-                io.Conditioning.Input(id="negative"),
-                io.Latent.Input(id="latent"),
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Latent.Input("latent"),
             ],
             outputs=[
-                io.Conditioning.Output(id="positive_out", display_name="positive"),
-                io.Conditioning.Output(id="negative_out", display_name="negative"),
-                io.Latent.Output(id="latent_out", display_name="latent"),
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
             ],
         )

@@ -342,19 +342,19 @@ class LTXVImgToVideo(io.ComfyNode):
             node_id="LTXVImgToVideo_V3",
             category="conditioning/video_models",
             inputs=[
-                io.Conditioning.Input(id="positive"),
-                io.Conditioning.Input(id="negative"),
-                io.Vae.Input(id="vae"),
-                io.Image.Input(id="image"),
-                io.Int.Input(id="width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
-                io.Int.Input(id="height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
-                io.Int.Input(id="length", default=97, min=9, max=nodes.MAX_RESOLUTION, step=8),
-                io.Int.Input(id="batch_size", default=1, min=1, max=4096),
-                io.Float.Input(id="strength", default=1.0, min=0.0, max=1.0),
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Image.Input("image"),
+                io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
+                io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
+                io.Int.Input("length", default=97, min=9, max=nodes.MAX_RESOLUTION, step=8),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.Float.Input("strength", default=1.0, min=0.0, max=1.0),
             ],
             outputs=[
-                io.Conditioning.Output(id="positive_out", display_name="positive"),
-                io.Conditioning.Output(id="negative_out", display_name="negative"),
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
                 io.Latent.Output(display_name="latent"),
             ],
         )
@@ -390,13 +390,13 @@ class LTXVPreprocess(io.ComfyNode):
             node_id="LTXVPreprocess_V3",
             category="image",
             inputs=[
-                io.Image.Input(id="image"),
+                io.Image.Input("image"),
                 io.Int.Input(
                     id="img_compression", default=35, min=0, max=100, tooltip="Amount of compression to apply on image."
                 ),
             ],
             outputs=[
-                io.Image.Output(id="output_image", display_name="output_image"),
+                io.Image.Output(display_name="output_image"),
             ],
         )

@@ -415,9 +415,9 @@ class LTXVScheduler(io.ComfyNode):
             node_id="LTXVScheduler_V3",
             category="sampling/custom_sampling/schedulers",
             inputs=[
-                io.Int.Input(id="steps", default=20, min=1, max=10000),
-                io.Float.Input(id="max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
-                io.Float.Input(id="base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
+                io.Int.Input("steps", default=20, min=1, max=10000),
+                io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
+                io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
                 io.Boolean.Input(
                     id="stretch",
                     default=True,
@@ -431,7 +431,7 @@ class LTXVScheduler(io.ComfyNode):
                     step=0.01,
                     tooltip="The terminal value of the sigmas after stretching.",
                 ),
-                io.Latent.Input(id="latent", optional=True),
+                io.Latent.Input("latent", optional=True),
             ],
             outputs=[
                 io.Sigmas.Output(),
@@ -478,10 +478,10 @@ class ModelSamplingLTXV(io.ComfyNode):
             node_id="ModelSamplingLTXV_V3",
             category="advanced/model",
             inputs=[
-                io.Model.Input(id="model"),
-                io.Float.Input(id="max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
-                io.Float.Input(id="base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
-                io.Latent.Input(id="latent", optional=True),
+                io.Model.Input("model"),
+                io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
+                io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
+                io.Latent.Input("latent", optional=True),
             ],
             outputs=[
                 io.Model.Output(),
comfy_extras/v3/nodes_lumina2.py (new file, 116 lines)
@@ -0,0 +1,116 @@
from __future__ import annotations

import torch

from comfy_api.v3 import io


class CLIPTextEncodeLumina2(io.ComfyNode):
    SYSTEM_PROMPT = {
        "superior": "You are an assistant designed to generate superior images with the superior "
                    "degree of image-text alignment based on textual prompts or user prompts.",
        "alignment": "You are an assistant designed to generate high-quality images with the "
                     "highest degree of image-text alignment based on textual prompts."
    }
    SYSTEM_PROMPT_TIP = "Lumina2 provides two types of system prompts: " \
        "Superior: You are an assistant designed to generate superior images with the superior "\
        "degree of image-text alignment based on textual prompts or user prompts. "\
        "Alignment: You are an assistant designed to generate high-quality images with the highest "\
        "degree of image-text alignment based on textual prompts."

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="CLIPTextEncodeLumina2_V3",
            display_name="CLIP Text Encode for Lumina2 _V3",
            category="conditioning",
            description="Encodes a system prompt and a user prompt using a CLIP model into an embedding "
                        "that can be used to guide the diffusion model towards generating specific images.",
            inputs=[
                io.Combo.Input("system_prompt", options=list(cls.SYSTEM_PROMPT.keys()), tooltip=cls.SYSTEM_PROMPT_TIP),
                io.String.Input("user_prompt", multiline=True, dynamic_prompts=True, tooltip="The text to be encoded."),
                io.Clip.Input("clip", tooltip="The CLIP model used for encoding the text."),
            ],
            outputs=[
                io.Conditioning.Output(tooltip="A conditioning containing the embedded text used to guide the diffusion model."),
            ],
        )

    @classmethod
    def execute(cls, system_prompt, user_prompt, clip):
        if clip is None:
            raise RuntimeError(
                "ERROR: clip input is invalid: None\n\n"
                "If the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model."
            )
        system_prompt = cls.SYSTEM_PROMPT[system_prompt]
        prompt = f'{system_prompt} <Prompt Start> {user_prompt}'
        tokens = clip.tokenize(prompt)
        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))


class RenormCFG(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="RenormCFG_V3",
            category="advanced/model",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("cfg_trunc", default=100, min=0.0, max=100.0, step=0.01),
                io.Float.Input("renorm_cfg", default=1.0, min=0.0, max=100.0, step=0.01),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, cfg_trunc, renorm_cfg):
        def renorm_cfg_func(args):
            cond_denoised = args["cond_denoised"]
            uncond_denoised = args["uncond_denoised"]
            cond_scale = args["cond_scale"]
            timestep = args["timestep"]
            x_orig = args["input"]
            in_channels = model.model.diffusion_model.in_channels

            if timestep[0] < cfg_trunc:
                cond_eps, uncond_eps = cond_denoised[:, :in_channels], uncond_denoised[:, :in_channels]
                cond_rest, _ = cond_denoised[:, in_channels:], uncond_denoised[:, in_channels:]
                half_eps = uncond_eps + cond_scale * (cond_eps - uncond_eps)
                half_rest = cond_rest

                if float(renorm_cfg) > 0.0:
                    ori_pos_norm = torch.linalg.vector_norm(
                        cond_eps,
                        dim=tuple(range(1, len(cond_eps.shape))),
                        keepdim=True
                    )
                    max_new_norm = ori_pos_norm * float(renorm_cfg)
                    new_pos_norm = torch.linalg.vector_norm(
                        half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True
                    )
                    if new_pos_norm >= max_new_norm:
                        half_eps = half_eps * (max_new_norm / new_pos_norm)
            else:
                cond_eps, uncond_eps = cond_denoised[:, :in_channels], uncond_denoised[:, :in_channels]
                cond_rest, _ = cond_denoised[:, in_channels:], uncond_denoised[:, in_channels:]
                half_eps = cond_eps
                half_rest = cond_rest

            cfg_result = torch.cat([half_eps, half_rest], dim=1)

            # cfg_result = uncond_denoised + (cond_denoised - uncond_denoised) * cond_scale

            return x_orig - cfg_result

        m = model.clone()
        m.set_model_sampler_cfg_function(renorm_cfg_func)
        return io.NodeOutput(m)


NODES_LIST = [
    CLIPTextEncodeLumina2,
    RenormCFG,
]
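A standalone sketch of the renorm step inside renorm_cfg_func above: the CFG-combined half_eps is rescaled whenever its norm exceeds renorm_cfg times the cond eps norm, so guidance cannot inflate the prediction magnitude. Batch size 1 is assumed here so the scalar comparison stays valid.

import torch

cond_eps = torch.randn(1, 4, 8, 8)
half_eps = 3.0 * torch.randn(1, 4, 8, 8)  # deliberately large CFG result
renorm_cfg = 1.0

ori_pos_norm = torch.linalg.vector_norm(cond_eps, dim=(1, 2, 3), keepdim=True)
max_new_norm = ori_pos_norm * renorm_cfg
new_pos_norm = torch.linalg.vector_norm(half_eps, dim=(1, 2, 3), keepdim=True)
if new_pos_norm >= max_new_norm:
    half_eps = half_eps * (max_new_norm / new_pos_norm)
assert torch.linalg.vector_norm(half_eps) <= max_new_norm.squeeze() + 1e-4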
comfy_extras/v3/nodes_morphology.py
@@ -23,12 +23,12 @@ class ImageRGBToYUV(io.ComfyNode):
             node_id="ImageRGBToYUV_V3",
             category="image/batch",
             inputs=[
-                io.Image.Input(id="image"),
+                io.Image.Input("image"),
             ],
             outputs=[
-                io.Image.Output(id="Y", display_name="Y"),
-                io.Image.Output(id="U", display_name="U"),
-                io.Image.Output(id="V", display_name="V"),
+                io.Image.Output(display_name="Y"),
+                io.Image.Output(display_name="U"),
+                io.Image.Output(display_name="V"),
             ],
         )

@@ -45,9 +45,9 @@ class ImageYUVToRGB(io.ComfyNode):
             node_id="ImageYUVToRGB_V3",
             category="image/batch",
             inputs=[
-                io.Image.Input(id="Y"),
-                io.Image.Input(id="U"),
-                io.Image.Input(id="V"),
+                io.Image.Input("Y"),
+                io.Image.Input("U"),
+                io.Image.Input("V"),
             ],
             outputs=[
                 io.Image.Output(),
@@ -68,9 +68,9 @@ class Morphology(io.ComfyNode):
             display_name="ImageMorphology _V3",
             category="image/postprocessing",
             inputs=[
-                io.Image.Input(id="image"),
-                io.Combo.Input(id="operation", options=["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"]),
-                io.Int.Input(id="kernel_size", default=3, min=3, max=999, step=1),
+                io.Image.Input("image"),
+                io.Combo.Input("operation", options=["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"]),
+                io.Int.Input("kernel_size", default=3, min=3, max=999, step=1),
             ],
             outputs=[
                 io.Image.Output(),
comfy_extras/v3/nodes_optimalsteps.py
@@ -33,9 +33,9 @@ class OptimalStepsScheduler(io.ComfyNode):
             node_id="OptimalStepsScheduler_V3",
             category="sampling/custom_sampling/schedulers",
             inputs=[
-                io.Combo.Input(id="model_type", options=["FLUX", "Wan", "Chroma"]),
-                io.Int.Input(id="steps", default=20, min=3, max=1000),
-                io.Float.Input(id="denoise", default=1.0, min=0.0, max=1.0, step=0.01),
+                io.Combo.Input("model_type", options=["FLUX", "Wan", "Chroma"]),
+                io.Int.Input("steps", default=20, min=3, max=1000),
+                io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
             ],
             outputs=[
                 io.Sigmas.Output(),
comfy_extras/v3/nodes_pag.py
@@ -17,8 +17,8 @@ class PerturbedAttentionGuidance(io.ComfyNode):
             node_id="PerturbedAttentionGuidance_V3",
             category="model_patches/unet",
             inputs=[
-                io.Model.Input(id="model"),
-                io.Float.Input(id="scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01),
+                io.Model.Input("model"),
+                io.Float.Input("scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01),
             ],
             outputs=[
                 io.Model.Output(),
comfy_extras/v3/nodes_perpneg.py
@@ -88,12 +88,12 @@ class PerpNegGuider(io.ComfyNode):
             node_id="PerpNegGuider_V3",
             category="_for_testing",
             inputs=[
-                io.Model.Input(id="model"),
-                io.Conditioning.Input(id="positive"),
-                io.Conditioning.Input(id="negative"),
-                io.Conditioning.Input(id="empty_conditioning"),
-                io.Float.Input(id="cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
-                io.Float.Input(id="neg_scale", default=1.0, min=0.0, max=100.0, step=0.01),
+                io.Model.Input("model"),
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Conditioning.Input("empty_conditioning"),
+                io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
+                io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01),
             ],
             outputs=[
                 io.Guider.Output(),
comfy_extras/v3/nodes_tcfg.py (new file, 70 lines)
@@ -0,0 +1,70 @@
"""TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)"""

from __future__ import annotations

import torch

from comfy_api.v3 import io


def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
    """Drop tangential components from uncond score to align with cond score."""
    # (B, 1, ...)
    batch_num = cond_score.shape[0]
    cond_score_flat = cond_score.reshape(batch_num, 1, -1).float()
    uncond_score_flat = uncond_score.reshape(batch_num, 1, -1).float()

    # Score matrix A (B, 2, ...)
    score_matrix = torch.cat((uncond_score_flat, cond_score_flat), dim=1)
    try:
        _, _, Vh = torch.linalg.svd(score_matrix, full_matrices=False)
    except RuntimeError:
        # Fallback to CPU
        _, _, Vh = torch.linalg.svd(score_matrix.cpu(), full_matrices=False)

    # Drop the tangential components
    v1 = Vh[:, 0:1, :].to(uncond_score_flat.device)  # (B, 1, ...)
    uncond_score_td = (uncond_score_flat @ v1.transpose(-2, -1)) * v1
    return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)


class TCFG(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TCFG_V3",
            display_name="Tangential Damping CFG _V3",
            category="advanced/guidance",
            description="TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality.",
            inputs=[
                io.Model.Input("model"),
            ],
            outputs=[
                io.Model.Output(display_name="patched_model"),
            ],
        )

    @classmethod
    def execute(cls, model):
        m = model.clone()

        def tangential_damping_cfg(args):
            # Assume [cond, uncond, ...]
            x = args["input"]
            conds_out = args["conds_out"]
            if len(conds_out) <= 1 or None in args["conds"][:2]:
                # Skip when either cond or uncond is None
                return conds_out
            cond_pred = conds_out[0]
            uncond_pred = conds_out[1]
            uncond_td = score_tangential_damping(x - cond_pred, x - uncond_pred)
            uncond_pred_td = x - uncond_td
            return [cond_pred, uncond_pred_td] + conds_out[2:]

        m.set_model_sampler_pre_cfg_function(tangential_damping_cfg)
        return io.NodeOutput(m)


NODES_LIST = [
    TCFG,
]
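A standalone toy run of score_tangential_damping (the function defined in nodes_tcfg.py above; run alongside that definition): the damped uncond score is the projection of the uncond score onto the dominant right-singular direction v1 of the stacked [uncond; cond] matrix, so it comes out parallel to v1. Shapes are illustrative.

import torch
import torch.nn.functional as F

cond = torch.randn(1, 4, 8, 8)
uncond = torch.randn(1, 4, 8, 8)

# Recompute v1 the same way the function does:
c = cond.reshape(1, 1, -1).float()
u = uncond.reshape(1, 1, -1).float()
_, _, Vh = torch.linalg.svd(torch.cat((u, c), dim=1), full_matrices=False)
v1 = Vh[:, 0, :]

uncond_td = score_tangential_damping(cond, uncond).reshape(1, -1)
cos = F.cosine_similarity(uncond_td, v1, dim=-1)
print(cos.abs().item())  # ~1.0: the damped score is parallel to v1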
comfy_extras/v3/nodes_tomesd.py (new file, 190 lines)
@@ -0,0 +1,190 @@
|
|||||||
|
"""Taken from: https://github.com/dbolya/tomesd"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Callable, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy_api.v3 import io
|
||||||
|
|
||||||
|
|
||||||
|
def do_nothing(x: torch.Tensor, mode:str=None):
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def mps_gather_workaround(input, dim, index):
|
||||||
|
if input.shape[-1] == 1:
|
||||||
|
return torch.gather(
|
||||||
|
input.unsqueeze(-1),
|
||||||
|
dim - 1 if dim < 0 else dim,
|
||||||
|
index.unsqueeze(-1)
|
||||||
|
).squeeze(-1)
|
||||||
|
return torch.gather(input, dim, index)
|
||||||
|
|
||||||
|
|
||||||
|
def bipartite_soft_matching_random2d(
|
||||||
|
metric: torch.Tensor,w: int, h: int, sx: int, sy: int, r: int, no_rand: bool = False
|
||||||
|
) -> Tuple[Callable, Callable]:
|
||||||
|
"""
|
||||||
|
Partitions the tokens into src and dst and merges r tokens from src to dst.
|
||||||
|
Dst tokens are partitioned by choosing one randomy in each (sx, sy) region.
|
||||||
|
Args:
|
||||||
|
- metric [B, N, C]: metric to use for similarity
|
||||||
|
- w: image width in tokens
|
||||||
|
- h: image height in tokens
|
||||||
|
- sx: stride in the x dimension for dst, must divide w
|
||||||
|
- sy: stride in the y dimension for dst, must divide h
|
||||||
|
- r: number of tokens to remove (by merging)
|
||||||
|
- no_rand: if true, disable randomness (use top left corner only)
|
||||||
|
"""
|
||||||
|
B, N, _ = metric.shape
|
||||||
|
|
||||||
|
if r <= 0 or w == 1 or h == 1:
|
||||||
|
return do_nothing, do_nothing
|
||||||
|
|
||||||
|
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
hsy, wsx = h // sy, w // sx
|
||||||
|
|
||||||
|
# For each sy by sx kernel, randomly assign one token to be dst and the rest src
|
||||||
|
if no_rand:
|
||||||
|
rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
|
||||||
|
else:
|
||||||
|
rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)
|
||||||
|
|
||||||
|
# The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead
|
||||||
|
idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
|
||||||
|
idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
|
||||||
|
idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)
|
||||||
|
|
||||||
|
# Image is not divisible by sx or sy so we need to move it into a new buffer
|
||||||
|
if (hsy * sy) < h or (wsx * sx) < w:
|
||||||
|
idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
|
||||||
|
idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
|
||||||
|
else:
|
||||||
|
idx_buffer = idx_buffer_view
|
||||||
|
|
||||||
|
# We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
|
||||||
|
rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)
|
||||||
|
|
||||||
|
# We're finished with these
|
||||||
|
del idx_buffer, idx_buffer_view
|
||||||
|
|
||||||
|
# rand_idx is currently dst|src, so split them
|
||||||
|
num_dst = hsy * wsx
|
||||||
|
a_idx = rand_idx[:, num_dst:, :] # src
|
||||||
|
b_idx = rand_idx[:, :num_dst, :] # dst
|
||||||
|
|
||||||
|
def split(x):
|
||||||
|
C = x.shape[-1]
|
||||||
|
src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
|
||||||
|
dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
|
||||||
|
return src, dst
|
||||||
|
|
||||||
|
# Cosine similarity between A and B
|
||||||
|
metric = metric / metric.norm(dim=-1, keepdim=True)
|
||||||
|
a, b = split(metric)
|
||||||
|
scores = a @ b.transpose(-1, -2)
|
||||||
|
|
||||||
|
        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], r)

        # Find the most similar greedily
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = split(x)
        n, t1, c = src.shape

        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        _, _, c = unm.shape

        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))

        # Combine back to the original shape
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)

        return out

    return merge, unmerge


def get_functions(x, ratio, original_shape):
    b, c, original_h, original_w = original_shape
    original_tokens = original_h * original_w
    downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
    stride_x = 2
    stride_y = 2
    max_downsample = 1

    if downsample <= max_downsample:
        w = int(math.ceil(original_w / downsample))
        h = int(math.ceil(original_h / downsample))
        r = int(x.shape[1] * ratio)
        no_rand = False
        m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
        return m, u

    def nothing(y):
        return y

    return nothing, nothing


class TomePatchModel(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TomePatchModel_V3",
            category="model_patches/unet",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("ratio", default=0.3, min=0.0, max=1.0, step=0.01),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, ratio):
        u = None

        def tomesd_m(q, k, v, extra_options):
            nonlocal u
            # NOTE: In the reference code get_functions takes x (the input of the transformer block) as the
            # argument instead of q; however, from my basic testing it seems that using q gives better results.
            m, u = get_functions(q, ratio, extra_options["original_shape"])
            return m(q), k, v

        def tomesd_u(n, extra_options):
            return u(n)

        m = model.clone()
        m.set_model_attn1_patch(tomesd_m)
        m.set_model_attn1_output_patch(tomesd_u)
        return io.NodeOutput(m)


NODES_LIST = [
    TomePatchModel,
]
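Reviewer note: the merge/unmerge pair above hinges on `scatter_reduce` with `reduce="mean"` to average merged source tokens into their destination tokens. A minimal, self-contained sketch of that core operation, using toy tensors only (no ComfyUI imports; shapes and index values here are illustrative):

```python
import torch

n, r, c = 1, 2, 4                      # batch, tokens to merge, channels
src = torch.randn(n, 3, c)             # source tokens
dst = torch.randn(n, 2, c)             # destination tokens
dst_idx = torch.tensor([[[0], [1]]])   # destination slot for each merged src token

merged = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src[:, :r], reduce="mean")
print(merged.shape)  # (1, 2, 4): each destination now carries the mean of itself and its merged sources
```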
comfy_extras/v3/nodes_torch_compile.py (new file, 32 lines)
@ -0,0 +1,32 @@
from __future__ import annotations

from comfy_api.torch_helpers import set_torch_compile_wrapper
from comfy_api.v3 import io


class TorchCompileModel(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TorchCompileModel_V3",
            category="_for_testing",
            is_experimental=True,
            inputs=[
                io.Model.Input("model"),
                io.Combo.Input("backend", options=["inductor", "cudagraphs"]),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, backend):
        m = model.clone()
        set_torch_compile_wrapper(model=m, backend=backend)
        return io.NodeOutput(m)


NODES_LIST = [
    TorchCompileModel,
]
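The two backend options correspond to standard `torch.compile` backends; roughly, the wrapper arranges something like the following (a minimal sketch, not the wrapper's actual internals, which also handle ComfyUI model patching):

```python
import torch

model = torch.nn.Linear(8, 8)
# "inductor" is the default compiler backend; "cudagraphs" requires a CUDA device.
compiled = torch.compile(model, backend="inductor")
out = compiled(torch.randn(2, 8))  # compilation happens lazily on first call
```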
comfy_extras/v3/nodes_train.py (new file, 666 lines)
@ -0,0 +1,666 @@
from __future__ import annotations

import logging
import os

import numpy as np
import safetensors
import torch
import torch.utils.checkpoint
import tqdm
from PIL import Image, ImageDraw, ImageFont

import comfy.model_management
import comfy.samplers
import comfy.sd
import comfy.utils
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.weight_adapter import adapter_maps, adapters
from comfy_api.v3 import io, ui


def make_batch_extra_option_dict(d, indicies, full_size=None):
    new_dict = {}
    for k, v in d.items():
        newv = v
        if isinstance(v, dict):
            newv = make_batch_extra_option_dict(v, indicies, full_size=full_size)
        elif isinstance(v, torch.Tensor):
            if full_size is None or v.size(0) == full_size:
                newv = v[indicies]
        elif isinstance(v, (list, tuple)) and len(v) == full_size:
            newv = [v[i] for i in indicies]
        new_dict[k] = newv
    return new_dict


class TrainSampler(comfy.samplers.Sampler):

    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.loss_callback = loss_callback
        self.batch_size = batch_size
        self.total_steps = total_steps
        self.grad_acc = grad_acc
        self.seed = seed
        self.training_dtype = training_dtype

    def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
        cond = model_wrap.conds["positive"]
        dataset_size = sigmas.size(0)
        torch.cuda.empty_cache()
        for i in (pbar := tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
            noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
            indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()

            batch_latent = torch.stack([latent_image[i] for i in indicies])
            batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
            batch_sigmas = [
                model_wrap.inner_model.model_sampling.percent_to_sigma(
                    torch.rand((1,)).item()
                ) for _ in range(min(self.batch_size, dataset_size))
            ]
            batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)

            xt = model_wrap.inner_model.model_sampling.noise_scaling(
                batch_sigmas,
                batch_noise,
                batch_latent,
                False
            )
            x0 = model_wrap.inner_model.model_sampling.noise_scaling(
                torch.zeros_like(batch_sigmas),
                torch.zeros_like(batch_noise),
                batch_latent,
                False
            )

            model_wrap.conds["positive"] = [
                cond[i] for i in indicies
            ]
            batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)

            with torch.autocast(xt.device.type, dtype=self.training_dtype):
                x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
                loss = self.loss_fn(x0_pred, x0)
            loss.backward()
            if self.loss_callback:
                self.loss_callback(loss.item())
                pbar.set_postfix({"loss": f"{loss.item():.4f}"})

            if (i + 1) % self.grad_acc == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
        torch.cuda.empty_cache()
        return torch.zeros_like(latent_image)
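The `(i + 1) % self.grad_acc` check is plain gradient accumulation: gradients from several backward passes are summed in place and the optimizer steps once per `grad_acc` iterations. A standalone toy illustration of the same rule:

```python
import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
grad_acc = 4
for i in range(8):
    loss = model(torch.randn(2, 4)).pow(2).mean()
    loss.backward()                # gradients accumulate across iterations
    if (i + 1) % grad_acc == 0:    # one optimizer step per grad_acc backwards
        opt.step()
        opt.zero_grad()
```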

class BiasDiff(torch.nn.Module):
    def __init__(self, bias):
        super().__init__()
        self.bias = bias

    def __call__(self, b):
        org_dtype = b.dtype
        return (b.to(self.bias) + self.bias).to(org_dtype)

    def passive_memory_usage(self):
        return self.bias.nelement() * self.bias.element_size()

    def move_to(self, device):
        self.to(device=device)
        return self.passive_memory_usage()


def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None):
    """Utility function to load and process a list of images.

    Args:
        image_files: List of image filenames
        input_dir: Base directory containing the images
        resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")

    Returns:
        torch.Tensor: Batch of processed images
    """
    if not image_files:
        raise ValueError("No valid images found in input")

    output_images = []

    for file in image_files:
        image_path = os.path.join(input_dir, file)
        img = node_helpers.pillow(Image.open, image_path)

        if img.mode == "I":
            img = img.point(lambda i: i * (1 / 255))
        img = img.convert("RGB")

        if w is None and h is None:
            w, h = img.size[0], img.size[1]

        # Resize image to first image
        if img.size[0] != w or img.size[1] != h:
            if resize_method == "Stretch":
                img = img.resize((w, h), Image.Resampling.LANCZOS)
            elif resize_method == "Crop":
                img = img.crop((0, 0, w, h))
            elif resize_method == "Pad":
                img = img.resize((w, h), Image.Resampling.LANCZOS)
            elif resize_method == "None":
                raise ValueError(
                    "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
                )

        img_array = np.array(img).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_array)[None,]
        output_images.append(img_tensor)

    return torch.cat(output_images, dim=0)
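For reference, the tensor layout produced here is ComfyUI's usual IMAGE convention: float32 in [0, 1] with shape (N, H, W, C). A toy example of the conversion step in isolation:

```python
import numpy as np
import torch
from PIL import Image

img = Image.new("RGB", (64, 64), "gray")         # stand-in for a dataset image
arr = np.array(img).astype(np.float32) / 255.0   # HWC, values in [0, 1]
batch = torch.from_numpy(arr)[None,]             # (1, 64, 64, 3)
```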

def draw_loss_graph(loss_map, steps):
    width, height = 500, 300
    img = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(img)

    min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
    scaled_loss = [(l_v - min_loss) / (max_loss - min_loss) for l_v in loss_map.values()]

    prev_point = (0, height - int(scaled_loss[0] * height))
    for i, l_v in enumerate(scaled_loss[1:], start=1):
        x = int(i / (steps - 1) * width)
        y = height - int(l_v * height)
        draw.line([prev_point, (x, y)], fill="blue", width=2)
        prev_point = (x, y)

    return img


def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None):
    if result is None:
        result = []
    elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
        result.append(model)
        logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
        return result
    name = name or "root"
    for next_name, child in model.named_children():
        find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}")
    return result


def patch(m):
    if not hasattr(m, "forward"):
        return
    org_forward = m.forward

    def fwd(args, kwargs):
        return org_forward(*args, **kwargs)

    def checkpointing_fwd(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(
            fwd, args, kwargs, use_reentrant=False
        )

    m.org_forward = org_forward
    m.forward = checkpointing_fwd


def unpatch(m):
    if hasattr(m, "org_forward"):
        m.forward = m.org_forward
        del m.org_forward
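`patch`/`unpatch` swap each top-level submodule's forward for a `torch.utils.checkpoint.checkpoint` call, trading recomputation for activation memory. Note why `fwd(args, kwargs)` takes the tuple and dict as two positionals: checkpoint forwards its own positional arguments to the wrapped function. A self-contained sketch of that exact calling convention:

```python
import torch
import torch.utils.checkpoint

layer = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
x = torch.randn(2, 8, requires_grad=True)

def fwd(args, kwargs):
    return layer(*args, **kwargs)

# Activations inside `layer` are recomputed during backward instead of stored:
y = torch.utils.checkpoint.checkpoint(fwd, (x,), {}, use_reentrant=False)
y.sum().backward()
```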

class LoadImageSetFromFolderNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoadImageSetFromFolderNode_V3",
            display_name="Load Image Dataset from Folder _V3",
            category="loaders",
            description="Loads a batch of images from a directory for training.",
            is_experimental=True,
            inputs=[
                io.Combo.Input(
                    "folder", options=folder_paths.get_input_subfolders(), tooltip="The folder to load images from."
                ),
                io.Combo.Input(
                    "resize_method", options=["None", "Stretch", "Crop", "Pad"], default="None", optional=True
                ),
            ],
            outputs=[
                io.Image.Output(),
            ],
        )

    @classmethod
    def execute(cls, folder, resize_method="None"):
        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
        valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
        image_files = [
            f
            for f in os.listdir(sub_input_dir)
            if any(f.lower().endswith(ext) for ext in valid_extensions)
        ]
        return io.NodeOutput(load_and_process_images(image_files, sub_input_dir, resize_method))


class LoadImageTextSetFromFolderNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoadImageTextSetFromFolderNode_V3",
            display_name="Load Image and Text Dataset from Folder _V3",
            category="loaders",
            description="Loads a batch of images and captions from a directory for training.",
            is_experimental=True,
            inputs=[
                io.Combo.Input("folder", options=folder_paths.get_input_subfolders(), tooltip="The folder to load images from."),
                io.Clip.Input("clip", tooltip="The CLIP model used for encoding the text."),
                io.Combo.Input("resize_method", options=["None", "Stretch", "Crop", "Pad"], default="None", optional=True),
                io.Int.Input("width", default=-1, min=-1, max=10000, step=1, tooltip="The width to resize the images to. -1 means use the original width.", optional=True),
                io.Int.Input("height", default=-1, min=-1, max=10000, step=1, tooltip="The height to resize the images to. -1 means use the original height.", optional=True),
            ],
            outputs=[
                io.Image.Output(),
                io.Conditioning.Output(),
            ],
        )

    @classmethod
    def execute(cls, folder, clip, resize_method="None", width=None, height=None):
        if clip is None:
            raise RuntimeError(
                "ERROR: clip input is invalid: None\n\n"
                "If the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model."
            )

        logging.info(f"Loading images from folder: {folder}")

        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
        valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]

        image_files = []
        for item in os.listdir(sub_input_dir):
            path = os.path.join(sub_input_dir, item)
            if any(item.lower().endswith(ext) for ext in valid_extensions):
                image_files.append(path)
            elif os.path.isdir(path):
                # Support kohya-ss/sd-scripts folder structure
                repeat = 1
                if item.split("_")[0].isdigit():
                    repeat = int(item.split("_")[0])
                image_files.extend([
                    os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions)
                ] * repeat)

        caption_file_path = [
            f.replace(os.path.splitext(f)[1], ".txt")
            for f in image_files
        ]
        captions = []
        for caption_file in caption_file_path:
            caption_path = os.path.join(sub_input_dir, caption_file)
            if os.path.exists(caption_path):
                with open(caption_path, "r", encoding="utf-8") as f:
                    caption = f.read().strip()
                    captions.append(caption)
            else:
                captions.append("")

        width = width if width != -1 else None
        height = height if height != -1 else None
        output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height)

        logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")

        logging.info(f"Encoding captions from {sub_input_dir}.")
        conditions = []
        empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
        for text in captions:
            if text == "":
                conditions.append(empty_cond)
                continue  # reuse the cached empty conditioning instead of re-encoding
            tokens = clip.tokenize(text)
            conditions.extend(clip.encode_from_tokens_scheduled(tokens))
        logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.")
        return io.NodeOutput(output_tensor, conditions)
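Caption pairing is purely filename-based: each image looks for a sibling `.txt` with the same stem, and missing captions fall back to the empty string (encoded once as `empty_cond`). A splitext-based sketch of the pairing rule (illustrative filenames only):

```python
import os

image_files = ["img001.png", "photo.jpg", "3_style/pic.webp"]
caption_files = [os.path.splitext(f)[0] + ".txt" for f in image_files]
print(caption_files)  # ['img001.txt', 'photo.txt', '3_style/pic.txt']
```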

class LoraModelLoader(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoraModelLoader_V3",
            display_name="Load LoRA Model _V3",
            category="loaders",
            description="Load Trained LoRA weights from Train LoRA node.",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The diffusion model the LoRA will be applied to."),
                io.LoraModel.Input("lora", tooltip="The LoRA model to apply to the diffusion model."),
                io.Float.Input("strength_model", default=1.0, min=-100.0, max=100.0, step=0.01, tooltip="How strongly to modify the diffusion model. This value can be negative."),
            ],
            outputs=[
                io.Model.Output(tooltip="The modified diffusion model."),
            ],
        )

    @classmethod
    def execute(cls, model, lora, strength_model):
        if strength_model == 0:
            return io.NodeOutput(model)

        model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
        return io.NodeOutput(model_lora)

class LossGraphNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LossGraphNode_V3",
            display_name="Plot Loss Graph _V3",
            category="training",
            description="Plots the loss graph and saves it to the output directory.",
            is_experimental=True,
            is_output_node=True,
            inputs=[
                io.LossMap.Input("loss"),  # TODO: original V1 node has also `default={}` parameter
                io.String.Input("filename_prefix", default="loss_graph"),
            ],
            outputs=[],
            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
        )

    @classmethod
    def execute(cls, loss, filename_prefix):
        loss_values = loss["loss"]
        width, height = 800, 480
        margin = 40

        img = Image.new(
            "RGB", (width + margin, height + margin), "white"
        )  # Extend canvas
        draw = ImageDraw.Draw(img)

        min_loss, max_loss = min(loss_values), max(loss_values)
        scaled_loss = [(l_v - min_loss) / (max_loss - min_loss) for l_v in loss_values]

        steps = len(loss_values)

        prev_point = (margin, height - int(scaled_loss[0] * height))
        for i, l_v in enumerate(scaled_loss[1:], start=1):
            x = margin + int(i / steps * width)  # Scale X properly
            y = height - int(l_v * height)
            draw.line([prev_point, (x, y)], fill="blue", width=2)
            prev_point = (x, y)

        draw.line([(margin, 0), (margin, height)], fill="black", width=2)  # Y-axis
        draw.line(
            [(margin, height), (width + margin, height)], fill="black", width=2
        )  # X-axis

        try:
            font = ImageFont.truetype("arial.ttf", 12)
        except IOError:
            font = ImageFont.load_default()

        # Add axis labels
        draw.text((5, height // 2), "Loss", font=font, fill="black")
        draw.text((width // 2, height + 10), "Steps", font=font, fill="black")

        # Add min/max loss values
        draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
        draw.text(
            (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
        )
        return io.NodeOutput(ui=ui.PreviewImage(img, cls=cls))
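The plot min-max normalizes the loss history into [0, 1] before rasterizing. One caveat worth flagging in review: a perfectly flat loss history makes `max_loss == min_loss` and the comprehension divides by zero. The normalization itself, numerically:

```python
loss_values = [0.90, 0.72, 0.65, 0.61]
lo, hi = min(loss_values), max(loss_values)
scaled = [(v - lo) / (hi - lo) for v in loss_values]  # [1.0, ~0.379, ~0.138, 0.0]
```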

class SaveLoRA(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SaveLoRA_V3",
            display_name="Save LoRA Weights _V3",
            category="loaders",
            is_experimental=True,
            is_output_node=True,
            inputs=[
                io.LoraModel.Input("lora", tooltip="The LoRA model to save. Do not use the model with LoRA layers."),
                io.String.Input("prefix", default="loras/ComfyUI_trained_lora", tooltip="The prefix to use for the saved LoRA file."),
                io.Int.Input("steps", tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.", optional=True),
            ],
            outputs=[],
        )

    @classmethod
    def execute(cls, lora, prefix, steps=None):
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
            prefix, folder_paths.get_output_directory()
        )
        if steps is None:
            output_checkpoint = f"{filename}_{counter:05}_.safetensors"
        else:
            output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
        safetensors.torch.save_file(lora, output_checkpoint)
        return io.NodeOutput()


class TrainLoraNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TrainLoraNode_V3",
            display_name="Train LoRA _V3",
            category="training",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The model to train the LoRA on."),
                io.Latent.Input("latents", tooltip="The Latents to use for training, serve as dataset/input of the model."),
                io.Conditioning.Input("positive", tooltip="The positive conditioning to use for training."),
                io.Int.Input("batch_size", default=1, min=1, max=10000, step=1, tooltip="The batch size to use for training."),
                io.Int.Input("grad_accumulation_steps", default=1, min=1, max=1024, step=1, tooltip="The number of gradient accumulation steps to use for training."),
                io.Int.Input("steps", default=16, min=1, max=100000, tooltip="The number of steps to train the LoRA for."),
                io.Float.Input("learning_rate", default=0.0005, min=0.0000001, max=1.0, step=0.000001, tooltip="The learning rate to use for training."),
                io.Int.Input("rank", default=8, min=1, max=128, tooltip="The rank of the LoRA layers."),
                io.Combo.Input("optimizer", options=["AdamW", "Adam", "SGD", "RMSprop"], default="AdamW", tooltip="The optimizer to use for training."),
                io.Combo.Input("loss_function", options=["MSE", "L1", "Huber", "SmoothL1"], default="MSE", tooltip="The loss function to use for training."),
                io.Int.Input("seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)"),
                io.Combo.Input("training_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for training."),
                io.Combo.Input("lora_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for lora."),
                io.Combo.Input("algorithm", options=list(adapter_maps.keys()), default=list(adapter_maps.keys())[0], tooltip="The algorithm to use for training."),
                io.Boolean.Input("gradient_checkpointing", default=True, tooltip="Use gradient checkpointing for training."),
                io.Combo.Input("existing_lora", options=folder_paths.get_filename_list("loras") + ["[None]"], default="[None]", tooltip="The existing LoRA to append to. Set to None for new LoRA."),
            ],
            outputs=[
                io.Model.Output(display_name="model_with_lora"),
                io.LoraModel.Output(display_name="lora"),
                io.LossMap.Output(display_name="loss"),
                io.Int.Output(display_name="steps"),
            ],
        )

    @classmethod
    def execute(
        cls,
        model,
        latents,
        positive,
        batch_size,
        steps,
        grad_accumulation_steps,
        learning_rate,
        rank,
        optimizer,
        loss_function,
        seed,
        training_dtype,
        lora_dtype,
        algorithm,
        gradient_checkpointing,
        existing_lora,
    ):
        mp = model.clone()
        dtype = node_helpers.string_to_torch_dtype(training_dtype)
        lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
        mp.set_model_compute_dtype(dtype)

        latents = latents["samples"].to(dtype)
        num_images = latents.shape[0]
        logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
        if len(positive) == 1 and num_images > 1:
            positive = positive * num_images
        elif len(positive) != num_images:
            raise ValueError(
                f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
            )

        with torch.inference_mode(False):
            lora_sd = {}
            generator = torch.Generator()
            generator.manual_seed(seed)

            # Load existing LoRA weights if provided
            existing_weights = {}
            existing_steps = 0
            if existing_lora != "[None]":
                lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
                # Extract steps from filename like "trained_lora_10_steps_20250225_203716"
                existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
                if lora_path:
                    existing_weights = comfy.utils.load_torch_file(lora_path)

            all_weight_adapters = []
            for n, m in mp.model.named_modules():
                if hasattr(m, "weight_function"):
                    if m.weight is not None:
                        key = "{}.weight".format(n)
                        shape = m.weight.shape
                        if len(shape) >= 2:
                            alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
                            dora_scale = existing_weights.get(
                                f"{key}.dora_scale", None
                            )
                            for adapter_cls in adapters:
                                existing_adapter = adapter_cls.load(
                                    n, existing_weights, alpha, dora_scale
                                )
                                if existing_adapter is not None:
                                    break
                            else:
                                existing_adapter = None
                                adapter_cls = adapter_maps[algorithm]

                            if existing_adapter is not None:
                                train_adapter = existing_adapter.to_train().to(lora_dtype)
                            else:
                                # Use LoRA with alpha=1.0 by default
                                train_adapter = adapter_cls.create_train(
                                    m.weight, rank=rank, alpha=1.0
                                ).to(lora_dtype)
                            for name, parameter in train_adapter.named_parameters():
                                lora_sd[f"{n}.{name}"] = parameter

                            mp.add_weight_wrapper(key, train_adapter)
                            all_weight_adapters.append(train_adapter)
                        else:
                            diff = torch.nn.Parameter(
                                torch.zeros(
                                    m.weight.shape, dtype=lora_dtype, requires_grad=True
                                )
                            )
                            diff_module = BiasDiff(diff)
                            mp.add_weight_wrapper(key, BiasDiff(diff))
                            all_weight_adapters.append(diff_module)
                            lora_sd["{}.diff".format(n)] = diff
                    if hasattr(m, "bias") and m.bias is not None:
                        key = "{}.bias".format(n)
                        bias = torch.nn.Parameter(
                            torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True)
                        )
                        bias_module = BiasDiff(bias)
                        lora_sd["{}.diff_b".format(n)] = bias
                        mp.add_weight_wrapper(key, BiasDiff(bias))
                        all_weight_adapters.append(bias_module)

            if optimizer == "Adam":
                optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
            elif optimizer == "AdamW":
                optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
            elif optimizer == "SGD":
                optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
            elif optimizer == "RMSprop":
                optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)

            # Setup loss function based on selection
            if loss_function == "MSE":
                criterion = torch.nn.MSELoss()
            elif loss_function == "L1":
                criterion = torch.nn.L1Loss()
            elif loss_function == "Huber":
                criterion = torch.nn.HuberLoss()
            elif loss_function == "SmoothL1":
                criterion = torch.nn.SmoothL1Loss()

            # setup models
            if gradient_checkpointing:
                for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
                    patch(m)
            mp.model.requires_grad_(False)
            comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)

            # Setup sampler and guider like in test script
            loss_map = {"loss": []}
            def loss_callback(loss):
                loss_map["loss"].append(loss)
            train_sampler = TrainSampler(
                criterion,
                optimizer,
                loss_callback=loss_callback,
                batch_size=batch_size,
                grad_acc=grad_accumulation_steps,
                total_steps=steps * grad_accumulation_steps,
                seed=seed,
                training_dtype=dtype
            )
            guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
            guider.set_conds(positive)  # Set conditioning from input

            # Training loop
            try:
                # Generate dummy sigmas and noise
                sigmas = torch.tensor(range(num_images))
                noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
                guider.sample(
                    noise.generate_noise({"samples": latents}),
                    latents,
                    train_sampler,
                    sigmas,
                    seed=noise.seed
                )
            finally:
                for m in mp.model.modules():
                    unpatch(m)
            del train_sampler, optimizer

            for adapter in all_weight_adapters:
                adapter.requires_grad_(False)

            for param in lora_sd:
                lora_sd[param] = lora_sd[param].to(lora_dtype)

            return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
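The optimizer and loss selections above are straightforward if/elif chains over torch built-ins. An equivalent table-driven sketch of the same mapping (same classes; the dict form is just an alternative style, not what the node uses):

```python
import torch

OPTIMIZERS = {"AdamW": torch.optim.AdamW, "Adam": torch.optim.Adam,
              "SGD": torch.optim.SGD, "RMSprop": torch.optim.RMSprop}
LOSSES = {"MSE": torch.nn.MSELoss, "L1": torch.nn.L1Loss,
          "Huber": torch.nn.HuberLoss, "SmoothL1": torch.nn.SmoothL1Loss}

params = [torch.nn.Parameter(torch.zeros(4))]
opt = OPTIMIZERS["AdamW"](params, lr=5e-4)
criterion = LOSSES["MSE"]()
```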

NODES_LIST = [
    LoadImageSetFromFolderNode,
    LoadImageTextSetFromFolderNode,
    LoraModelLoader,
    LossGraphNode,
    SaveLoRA,
    TrainLoraNode,
]
comfy_extras/v3/nodes_upscale_model.py (new file, 106 lines)
@ -0,0 +1,106 @@
from __future__ import annotations

import logging

import torch
from spandrel import ImageModelDescriptor, ModelLoader

import comfy.utils
import folder_paths
from comfy import model_management
from comfy_api.v3 import io

try:
    from spandrel import MAIN_REGISTRY
    from spandrel_extra_arches import EXTRA_REGISTRY
    MAIN_REGISTRY.add(*EXTRA_REGISTRY)
    logging.info("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
except Exception:
    pass


class ImageUpscaleWithModel(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ImageUpscaleWithModel_V3",
            display_name="Upscale Image (using Model) _V3",
            category="image/upscaling",
            inputs=[
                io.UpscaleModel.Input("upscale_model"),
                io.Image.Input("image"),
            ],
            outputs=[
                io.Image.Output(),
            ],
        )

    @classmethod
    def execute(cls, upscale_model, image):
        device = model_management.get_torch_device()

        memory_required = model_management.module_size(upscale_model.model)
        memory_required += (512 * 512 * 3) * image.element_size() * max(upscale_model.scale, 1.0) * 384.0  # The 384.0 is an estimate of how much some of these models take, TODO: make it more accurate
        memory_required += image.nelement() * image.element_size()
        model_management.free_memory(memory_required, device)

        upscale_model.to(device)
        in_img = image.movedim(-1, -3).to(device)

        tile = 512
        overlap = 32

        oom = True
        while oom:
            try:
                steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(
                    in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap
                )
                pbar = comfy.utils.ProgressBar(steps)
                s = comfy.utils.tiled_scale(
                    in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar
                )
                oom = False
            except model_management.OOM_EXCEPTION as e:
                tile //= 2
                if tile < 128:
                    raise e

        upscale_model.to("cpu")
        s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
        return io.NodeOutput(s)


class UpscaleModelLoader(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="UpscaleModelLoader_V3",
            display_name="Load Upscale Model _V3",
            category="loaders",
            inputs=[
                io.Combo.Input("model_name", options=folder_paths.get_filename_list("upscale_models")),
            ],
            outputs=[
                io.UpscaleModel.Output(),
            ],
        )

    @classmethod
    def execute(cls, model_name):
        model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
        sd = comfy.utils.load_torch_file(model_path, safe_load=True)
        if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
            sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": ""})
        out = ModelLoader().load_from_state_dict(sd).eval()

        if not isinstance(out, ImageModelDescriptor):
            raise Exception("Upscale model must be a single-image model.")

        return io.NodeOutput(out)


NODES_LIST = [
    ImageUpscaleWithModel,
    UpscaleModelLoader,
]
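The upscale node retries with progressively smaller tiles on out-of-memory, halving down to a 128px floor before giving up. The retry skeleton in isolation (with a hypothetical `run_tiled` callable and `MemoryError` standing in for `comfy.utils.tiled_scale` and `model_management.OOM_EXCEPTION`):

```python
def upscale_with_fallback(run_tiled, tile=512):
    """run_tiled: hypothetical callable that raises on OOM when the tile is too large."""
    while True:
        try:
            return run_tiled(tile)
        except MemoryError:
            tile //= 2           # halve the tile and try again
            if tile < 128:
                raise            # below the floor, give up and re-raise
```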
comfy_extras/v3/nodes_video_model.py (new file, 232 lines)
@ -0,0 +1,232 @@
from __future__ import annotations

import torch

import comfy.sd
import comfy.utils
import comfy_extras.nodes_model_merging
import folder_paths
import node_helpers
import nodes
from comfy_api.v3 import io


class ConditioningSetAreaPercentageVideo(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ConditioningSetAreaPercentageVideo_V3",
            category="conditioning",
            inputs=[
                io.Conditioning.Input("conditioning"),
                io.Float.Input("width", default=1.0, min=0, max=1.0, step=0.01),
                io.Float.Input("height", default=1.0, min=0, max=1.0, step=0.01),
                io.Float.Input("temporal", default=1.0, min=0, max=1.0, step=0.01),
                io.Float.Input("x", default=0, min=0, max=1.0, step=0.01),
                io.Float.Input("y", default=0, min=0, max=1.0, step=0.01),
                io.Float.Input("z", default=0, min=0, max=1.0, step=0.01),
                io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
            ],
            outputs=[
                io.Conditioning.Output(),
            ],
        )

    @classmethod
    def execute(cls, conditioning, width, height, temporal, x, y, z, strength):
        c = node_helpers.conditioning_set_values(
            conditioning,
            {
                "area": ("percentage", temporal, height, width, z, y, x),
                "strength": strength,
                "set_area_to_bounds": False,
            }
        )
        return io.NodeOutput(c)


class ImageOnlyCheckpointLoader(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ImageOnlyCheckpointLoader_V3",
            display_name="Image Only Checkpoint Loader (img2vid model) _V3",
            category="loaders/video_models",
            inputs=[
                io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("checkpoints")),
            ],
            outputs=[
                io.Model.Output(),
                io.ClipVision.Output(),
                io.Vae.Output(),
            ],
        )

    @classmethod
    def execute(cls, ckpt_name):
        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
        out = comfy.sd.load_checkpoint_guess_config(
            ckpt_path,
            output_vae=True,
            output_clip=False,
            output_clipvision=True,
            embedding_directory=folder_paths.get_folder_paths("embeddings"),
        )
        return io.NodeOutput(out[0], out[3], out[2])


class ImageOnlyCheckpointSave(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ImageOnlyCheckpointSave_V3",
            category="advanced/model_merging",
            inputs=[
                io.Model.Input("model"),
                io.ClipVision.Input("clip_vision"),
                io.Vae.Input("vae"),
                io.String.Input("filename_prefix", default="checkpoints/ComfyUI"),
            ],
            outputs=[],
            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
        )

    @classmethod
    def execute(cls, model, clip_vision, vae, filename_prefix):
        output_dir = folder_paths.get_output_directory()
        comfy_extras.nodes_model_merging.save_checkpoint(
            model,
            clip_vision=clip_vision,
            vae=vae,
            filename_prefix=filename_prefix,
            output_dir=output_dir,
            prompt=cls.hidden.prompt,
            extra_pnginfo=cls.hidden.extra_pnginfo,
        )
        return io.NodeOutput()


class SVD_img2vid_Conditioning(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SVD_img2vid_Conditioning_V3",
            category="conditioning/video_models",
            inputs=[
                io.ClipVision.Input("clip_vision"),
                io.Image.Input("init_image"),
                io.Vae.Input("vae"),
                io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=8),
                io.Int.Input("height", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8),
                io.Int.Input("video_frames", default=14, min=1, max=4096),
                io.Int.Input("motion_bucket_id", default=127, min=1, max=1023),
                io.Int.Input("fps", default=6, min=1, max=1024),
                io.Float.Input("augmentation_level", default=0.0, min=0.0, max=10.0, step=0.01),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level):
        output = clip_vision.encode_image(init_image)
        pooled = output.image_embeds.unsqueeze(0)
        pixels = comfy.utils.common_upscale(
            init_image.movedim(-1, 1), width, height, "bilinear", "center"
        ).movedim(1, -1)
        encode_pixels = pixels[:, :, :, :3]
        if augmentation_level > 0:
            encode_pixels += torch.randn_like(pixels) * augmentation_level
        t = vae.encode(encode_pixels)
        positive = [
            [
                pooled,
                {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t},
            ]
        ]
        negative = [
            [
                torch.zeros_like(pooled),
                {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)},
            ]
        ]
        latent = torch.zeros([video_frames, 4, height // 8, width // 8])
        return io.NodeOutput(positive, negative, {"samples": latent})


class VideoLinearCFGGuidance(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="VideoLinearCFGGuidance_V3",
            category="sampling/video_models",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("min_cfg", default=1.0, min=0.0, max=100.0, step=0.5, round=0.01),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, min_cfg):
        def linear_cfg(args):
            cond = args["cond"]
            uncond = args["uncond"]
            cond_scale = args["cond_scale"]

            scale = torch.linspace(
                min_cfg, cond_scale, cond.shape[0], device=cond.device
            ).reshape((cond.shape[0], 1, 1, 1))
            return uncond + scale * (cond - uncond)

        m = model.clone()
        m.set_model_sampler_cfg_function(linear_cfg)
        return io.NodeOutput(m)


class VideoTriangleCFGGuidance(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="VideoTriangleCFGGuidance_V3",
            category="sampling/video_models",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("min_cfg", default=1.0, min=0.0, max=100.0, step=0.5, round=0.01),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, min_cfg):
        def linear_cfg(args):
            cond = args["cond"]
            uncond = args["uncond"]
            cond_scale = args["cond_scale"]
            period = 1.0
            values = torch.linspace(0, 1, cond.shape[0], device=cond.device)
            values = 2 * (values / period - torch.floor(values / period + 0.5)).abs()
            scale = (values * (cond_scale - min_cfg) + min_cfg).reshape((cond.shape[0], 1, 1, 1))

            return uncond + scale * (cond - uncond)

        m = model.clone()
        m.set_model_sampler_cfg_function(linear_cfg)
        return io.NodeOutput(m)


NODES_LIST = [
    ConditioningSetAreaPercentageVideo,
    ImageOnlyCheckpointLoader,
    ImageOnlyCheckpointSave,
    SVD_img2vid_Conditioning,
    VideoLinearCFGGuidance,
    VideoTriangleCFGGuidance,
]
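For intuition, the linear-CFG guider interpolates the guidance scale per batch entry (i.e. per video frame), from `min_cfg` at frame 0 up to the sampler's `cond_scale` at the last frame. A standalone numeric sketch of the same computation on toy tensors:

```python
import torch

cond = torch.randn(14, 4, 8, 8)     # one entry per video frame
uncond = torch.randn_like(cond)
min_cfg, cond_scale = 1.0, 2.5

scale = torch.linspace(min_cfg, cond_scale, cond.shape[0]).reshape(-1, 1, 1, 1)
guided = uncond + scale * (cond - uncond)  # frame 0 uses min_cfg, last frame uses cond_scale
```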
cuda_malloc.py
@ -74,7 +74,8 @@ if not args.cuda_malloc:
         module = importlib.util.module_from_spec(spec)
         spec.loader.exec_module(module)
         version = module.__version__
-        if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
+
+        if int(version[0]) >= 2 and "+cu" in version: #enable by default for torch version 2.0 and up only on cuda torch
             args.cuda_malloc = cuda_malloc_supported()
     except:
         pass
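The tightened condition only enables cudaMallocAsync for CUDA builds of torch, whose version strings carry a `+cu` local tag. A quick illustration of how the check behaves on typical version strings:

```python
for version in ("2.4.1+cu121", "2.4.1+rocm6.1", "1.13.1+cu117"):
    enabled = int(version[0]) >= 2 and "+cu" in version
    print(version, enabled)  # True, False, False
```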
nodes.py (14 additions)
@ -2320,9 +2320,17 @@ def init_builtin_extra_nodes():
         "v3/nodes_fresca.py",
         "v3/nodes_gits.py",
         "v3/nodes_hidream.py",
+        "v3/nodes_hunyuan.py",
+        "v3/nodes_hypernetwork.py",
+        "v3/nodes_hypertile.py",
         "v3/nodes_images.py",
+        "v3/nodes_ip2p.py",
         "v3/nodes_latent.py",
+        "v3/nodes_load_3d.py",
+        "v3/nodes_lora_extract.py",
+        "v3/nodes_lotus.py",
         "v3/nodes_lt.py",
+        "v3/nodes_lumina2.py",
         "v3/nodes_mask.py",
         "v3/nodes_mochi.py",
         "v3/nodes_model_advanced.py",
@ -2342,7 +2350,13 @@ def init_builtin_extra_nodes():
         "v3/nodes_sdupscale.py",
         "v3/nodes_slg.py",
         "v3/nodes_stable_cascade.py",
+        "v3/nodes_tcfg.py",
+        "v3/nodes_tomesd.py",
+        "v3/nodes_torch_compile.py",
+        "v3/nodes_train.py",
+        "v3/nodes_upscale_model.py",
         "v3/nodes_video.py",
+        "v3/nodes_video_model.py",
         "v3/nodes_wan.py",
         "v3/nodes_webcam.py",
     ]
server.py
@ -554,6 +554,7 @@ class PromptServer():
             ram_free = comfy.model_management.get_free_memory(cpu_device)
             vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
             vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
+            required_frontend_version = FrontendManager.get_required_frontend_version()
 
             system_stats = {
                 "system": {
@ -561,6 +562,7 @@ class PromptServer():
                     "ram_total": ram_total,
                     "ram_free": ram_free,
                     "comfyui_version": __version__,
+                    "required_frontend_version": required_frontend_version,
                     "python_version": sys.version,
                     "pytorch_version": comfy.model_management.torch_version,
                     "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
@ -1,7 +1,7 @@
 import argparse
 import pytest
 from requests.exceptions import HTTPError
-from unittest.mock import patch
+from unittest.mock import patch, mock_open
 
 from app.frontend_management import (
     FrontendManager,
@ -172,3 +172,36 @@ def test_init_frontend_fallback_on_error():
     # Assert
     assert frontend_path == "/default/path"
     mock_check.assert_called_once()
+
+
+def test_get_frontend_version():
+    # Arrange
+    expected_version = "1.25.0"
+    mock_requirements_content = """torch
+torchsde
+comfyui-frontend-package==1.25.0
+other-package==1.0.0
+numpy"""
+
+    # Act
+    with patch("builtins.open", mock_open(read_data=mock_requirements_content)):
+        version = FrontendManager.get_required_frontend_version()
+
+    # Assert
+    assert version == expected_version
+
+
+def test_get_frontend_version_invalid_semver():
+    # Arrange
+    mock_requirements_content = """torch
+torchsde
+comfyui-frontend-package==1.29.3.75
+other-package==1.0.0
+numpy"""
+
+    # Act
+    with patch("builtins.open", mock_open(read_data=mock_requirements_content)):
+        version = FrontendManager.get_required_frontend_version()
+
+    # Assert
+    assert version is None