Merge branch 'v3-definition' into v3-definition-wip

Jedrzej Kosinski 2025-07-24 16:00:58 -07:00
commit 9bd3faaf1f
42 changed files with 2897 additions and 183 deletions

View File

@ -294,6 +294,13 @@ For models compatible with Cambricon Extension for PyTorch (torch_mlu). Here's a
2. Next, install the PyTorch(torch_mlu) following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html)
3. Launch ComfyUI by running `python main.py`
#### Iluvatar Corex
For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step guide tailored to your platform and installation method:
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
2. Launch ComfyUI by running `python main.py`
# Running

```python main.py```

View File

@ -29,18 +29,48 @@ def frontend_install_warning_message():
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
""".strip()
def parse_version(version: str) -> tuple[int, int, int]:
return tuple(map(int, version.split(".")))
def is_valid_version(version: str) -> bool:
"""Validate if a string is a valid semantic version (X.Y.Z format)."""
pattern = r"^(\d+)\.(\d+)\.(\d+)$"
return bool(re.match(pattern, version))
def get_installed_frontend_version():
"""Get the currently installed frontend package version."""
frontend_version_str = version("comfyui-frontend-package")
return frontend_version_str
def get_required_frontend_version():
"""Get the required frontend version from requirements.txt."""
try:
with open(requirements_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("comfyui-frontend-package=="):
version_str = line.split("==")[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid version format in requirements.txt: {version_str}")
return None
return version_str
logging.error("comfyui-frontend-package not found in requirements.txt")
return None
except FileNotFoundError:
logging.error("requirements.txt not found. Cannot determine required frontend version.")
return None
except Exception as e:
logging.error(f"Error reading requirements.txt: {e}")
return None
def check_frontend_version():
    """Check if the frontend version is up to date."""

-   def parse_version(version: str) -> tuple[int, int, int]:
-       return tuple(map(int, version.split(".")))
-
    try:
-       frontend_version_str = version("comfyui-frontend-package")
+       frontend_version_str = get_installed_frontend_version()
        frontend_version = parse_version(frontend_version_str)
-       with open(requirements_path, "r", encoding="utf-8") as f:
-           required_frontend = parse_version(f.readline().split("=")[-1])
+       required_frontend_str = get_required_frontend_version()
+       required_frontend = parse_version(required_frontend_str)
        if frontend_version < required_frontend:
            app.logger.log_startup_warning(
                f"""
@ -168,6 +198,11 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
class FrontendManager:
    CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")

    @classmethod
    def get_required_frontend_version(cls) -> str:
        """Get the required frontend package version."""
        return get_required_frontend_version()

    @classmethod
    def default_frontend_path(cls) -> str:
        try:
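For reference, a minimal, self-contained sketch of the comparison these helpers implement; the two functions are re-declared here rather than imported from app.frontend_management so the snippet runs on its own:

```python
import re

def is_valid_version(version: str) -> bool:
    """X.Y.Z with numeric components only."""
    return bool(re.match(r"^(\d+)\.(\d+)\.(\d+)$", version))

def parse_version(version: str) -> tuple[int, int, int]:
    return tuple(map(int, version.split(".")))

assert is_valid_version("1.23.4") and not is_valid_version("1.23")
# Tuple comparison handles multi-digit components correctly,
# unlike a plain string comparison ("1.10.0" < "1.9.9" as strings).
assert parse_version("1.10.0") > parse_version("1.9.9")
```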

View File

@ -1210,39 +1210,21 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
    return x_next


-@torch.no_grad()
-def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
-    extra_args = {} if extra_args is None else extra_args
-
-    temp = [0]
-    def post_cfg_function(args):
-        temp[0] = args["uncond_denoised"]
-        return args["denoised"]
-
-    model_options = extra_args.get("model_options", {}).copy()
-    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
-
-    s_in = x.new_ones([x.shape[0]])
-    for i in trange(len(sigmas) - 1, disable=disable):
-        sigma_hat = sigmas[i]
-        denoised = model(x, sigma_hat * s_in, **extra_args)
-        d = to_d(x, sigma_hat, temp[0])
-        if callback is not None:
-            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
-        # Euler method
-        x = denoised + d * sigmas[i + 1]
-    return x
-
@torch.no_grad()
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
-   """Ancestral sampling with Euler method steps."""
+   """Ancestral sampling with Euler method steps (CFG++)."""
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler

-   temp = [0]
+   model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
+   lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
+
+   uncond_denoised = None
    def post_cfg_function(args):
-       temp[0] = args["uncond_denoised"]
+       nonlocal uncond_denoised
+       uncond_denoised = args["uncond_denoised"]
        return args["denoised"]

    model_options = extra_args.get("model_options", {}).copy()
@ -1251,15 +1233,33 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
-       sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
-       d = to_d(x, sigmas[i], temp[0])
-       # Euler method
-       x = denoised + d * sigma_down
-       if sigmas[i + 1] > 0:
-           x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
+       if sigmas[i + 1] == 0:
+           # Denoising step
+           x = denoised
+       else:
+           alpha_s = sigmas[i] * lambda_fn(sigmas[i]).exp()
+           alpha_t = sigmas[i + 1] * lambda_fn(sigmas[i + 1]).exp()
+           d = to_d(x, sigmas[i], alpha_s * uncond_denoised)  # to noise
+
+           # DDIM stochastic sampling
+           sigma_down, sigma_up = get_ancestral_step(sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta)
+           sigma_down = alpha_t * sigma_down
+
+           # Euler method
+           x = alpha_t * denoised + sigma_down * d
+           if eta > 0 and s_noise > 0:
+               x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
@torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""Euler method steps (CFG++)."""
return sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None)
@torch.no_grad()
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver++(2S) second-order steps."""

View File

@ -128,6 +128,11 @@ try:
except:
    mlu_available = False

try:
    ixuca_available = hasattr(torch, "corex")
except:
    ixuca_available = False

if args.cpu:
    cpu_state = CPUState.CPU
@ -151,6 +156,12 @@ def is_mlu():
        return True
    return False
def is_ixuca():
global ixuca_available
if ixuca_available:
return True
return False
def get_torch_device():
    global directml_enabled
    global cpu_state
@ -289,7 +300,7 @@ try:
    if torch_version_numeric[0] >= 2:
        if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
            ENABLE_PYTORCH_ATTENTION = True
-       if is_intel_xpu() or is_ascend_npu() or is_mlu():
+       if is_intel_xpu() or is_ascend_npu() or is_mlu() or is_ixuca():
            if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                ENABLE_PYTORCH_ATTENTION = True
except:
@ -1045,6 +1056,8 @@ def xformers_enabled():
        return False
    if is_mlu():
        return False
    if is_ixuca():
        return False
    if directml_enabled:
        return False
    return XFORMERS_IS_AVAILABLE
@ -1080,6 +1093,8 @@ def pytorch_attention_flash_attention():
        return True
    if is_amd():
        return True  #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
    if is_ixuca():
        return True
    return False

def force_upcast_attention_dtype():
@ -1205,6 +1220,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
    if is_mlu():
        return True
    if is_ixuca():
        return True
    if torch.version.hip:
        return True
@ -1268,6 +1286,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
    if is_ascend_npu():
        return True
    if is_ixuca():
        return True
    if is_amd():
        arch = torch.cuda.get_device_properties(device).gcnArchName
        if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]):  # RDNA2 and older don't support bf16

View File

@ -15,9 +15,20 @@ adapters: list[type[WeightAdapterBase]] = [
    OFTAdapter,
    BOFTAdapter,
]
adapter_maps: dict[str, type[WeightAdapterBase]] = {
"LoRA": LoRAAdapter,
"LoHa": LoHaAdapter,
"LoKr": LoKrAdapter,
"OFT": OFTAdapter,
## We disable not implemented algo for now
# "GLoRA": GLoRAAdapter,
# "BOFT": BOFTAdapter,
}
__all__ = [
    "WeightAdapterBase",
    "WeightAdapterTrainBase",
-   "adapters"
+   "adapters",
+   "adapter_maps",
] + [a.__name__ for a in adapters]

View File

@ -133,3 +133,43 @@ def tucker_weight_from_conv(up, down, mid):
def tucker_weight(wa, wb, t):
    temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
    return torch.einsum("i j ..., i r -> r j ...", temp, wa)
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
"""
    Split `dimension` into a tuple (m, n) with m * n == dimension, using the
    divisor closest to `factor` (or the most square split when factor is -1);
    the second value is always greater than or equal to the first.

    Examples (rows are `dimension`, columns are `factor`):

        factor:        -1        2         4         8         16
        127    ->    1, 127    1, 127    1, 127    1, 127    1, 127
        128    ->    8, 16     2, 64     4, 32     8, 16     8, 16
        250    ->    10, 25    2, 125    2, 125    5, 50     10, 25
        360    ->    8, 45     2, 180    4, 90     8, 45     12, 30
        512    ->    16, 32    2, 256    4, 128    8, 64     16, 32
        1024   ->    32, 32    2, 512    4, 256    8, 128    16, 64
"""
if factor > 0 and (dimension % factor) == 0 and dimension >= factor**2:
m = factor
n = dimension // factor
if m > n:
n, m = m, n
return m, n
if factor < 0:
factor = dimension
m, n = 1, dimension
length = m + n
while m < n:
new_m = m + 1
while dimension % new_m != 0:
new_m += 1
new_n = dimension // new_m
if new_m + new_n > length or new_m > factor:
break
else:
m, n = new_m, new_n
if m > n:
n, m = m, n
return m, n
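A quick usage sketch against the docstring's table; the import path is an assumption based on the `from .base import ... factorization` lines in the lokr.py and oft.py hunks further down:

```python
from comfy.weight_adapter.base import factorization

print(factorization(128))      # (8, 16)  - factor=-1 walks toward the most square split
print(factorization(360, 8))   # (8, 45)  - 8 divides 360 and 360 >= 8**2, so it is used directly
print(factorization(127, 4))   # (1, 127) - primes can only be split trivially
```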

View File

@ -3,7 +3,120 @@ from typing import Optional
import torch
import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose
+from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose
class HadaWeight(torch.autograd.Function):
@staticmethod
def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
ctx.save_for_backward(w1d, w1u, w2d, w2u, scale)
diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale
return diff_weight
@staticmethod
def backward(ctx, grad_out):
(w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors
grad_out = grad_out * scale
temp = grad_out * (w2u @ w2d)
grad_w1u = temp @ w1d.T
grad_w1d = w1u.T @ temp
temp = grad_out * (w1u @ w1d)
grad_w2u = temp @ w2d.T
grad_w2d = w2u.T @ temp
del temp
return grad_w1u, grad_w1d, grad_w2u, grad_w2d, None
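For orientation (my reading of the diff, writing B1 = w1u, A1 = w1d, B2 = w2u, A2 = w2d and γ for `scale`), the non-Tucker forward and the hand-written backward are

$$ \Delta W = \gamma\,(B_1 A_1)\odot(B_2 A_2),\qquad \frac{\partial L}{\partial B_1} = \gamma\,\bigl(G\odot(B_2 A_2)\bigr)A_1^{\top},\qquad \frac{\partial L}{\partial A_1} = \gamma\,B_1^{\top}\bigl(G\odot(B_2 A_2)\bigr), $$

with G = ∂L/∂ΔW and the B2/A2 gradients obtained symmetrically. Writing the backward by hand means the full-size elementwise product is recomputed from the saved low-rank factors instead of being kept alive by autograd; that reading of the motivation is mine, not stated in the diff.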
class HadaWeightTucker(torch.autograd.Function):
@staticmethod
def forward(ctx, t1, w1u, w1d, t2, w2u, w2d, scale=torch.tensor(1)):
ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale)
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u)
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u)
return rebuild1 * rebuild2 * scale
@staticmethod
def backward(ctx, grad_out):
(t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors
grad_out = grad_out * scale
temp = torch.einsum("i j ..., j r -> i r ...", t2, w2d)
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2u)
grad_w = rebuild * grad_out
del rebuild
grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T)
del grad_w, temp
grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T)
del grad_temp
temp = torch.einsum("i j ..., j r -> i r ...", t1, w1d)
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1u)
grad_w = rebuild * grad_out
del rebuild
grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T)
del grad_w, temp
grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T)
del grad_temp
return grad_t1, grad_w1u, grad_w1d, grad_t2, grad_w2u, grad_w2d, None
class LohaDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
# Unpack weights tuple from LoHaAdapter
w1a, w1b, alpha, w2a, w2b, t1, t2, _ = weights
# Create trainable parameters
self.hada_w1_a = torch.nn.Parameter(w1a)
self.hada_w1_b = torch.nn.Parameter(w1b)
self.hada_w2_a = torch.nn.Parameter(w2a)
self.hada_w2_b = torch.nn.Parameter(w2b)
self.use_tucker = False
if t1 is not None and t2 is not None:
self.use_tucker = True
self.hada_t1 = torch.nn.Parameter(t1)
self.hada_t2 = torch.nn.Parameter(t2)
else:
# Keep the attributes for consistent access
self.hada_t1 = None
self.hada_t2 = None
# Store rank and non-trainable alpha
self.rank = w1b.shape[0]
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
def __call__(self, w):
org_dtype = w.dtype
scale = self.alpha / self.rank
if self.use_tucker:
diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale)
else:
diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
# Add the scaled difference to the original weight
weight = w.to(diff_weight) + diff_weight.reshape(w.shape)
return weight.to(org_dtype)
def passive_memory_usage(self):
"""Calculates memory usage of the trainable parameters."""
return sum(param.numel() * param.element_size() for param in self.parameters())
class LoHaAdapter(WeightAdapterBase):
@ -13,6 +126,25 @@ class LoHaAdapter(WeightAdapterBase):
        self.loaded_keys = loaded_keys
        self.weights = weights
@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel()
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
torch.nn.init.normal_(mat1, 0.1)
torch.nn.init.constant_(mat2, 0.0)
mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
torch.nn.init.normal_(mat3, 0.1)
torch.nn.init.normal_(mat4, 0.01)
return LohaDiff(
(mat1, mat2, alpha, mat3, mat4, None, None, None)
)
def to_train(self):
return LohaDiff(self.weights)
    @classmethod
    def load(
        cls,

View File

@ -3,7 +3,77 @@ from typing import Optional
import torch
import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose
+from .base import (
+    WeightAdapterBase,
+    WeightAdapterTrainBase,
+    weight_decompose,
+    factorization,
+)
class LokrDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
(lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights
self.use_tucker = False
if lokr_w1_a is not None:
_, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1]
rank_a, _ = lokr_w1_b.shape[0], lokr_w1_b.shape[1]
self.lokr_w1_a = torch.nn.Parameter(lokr_w1_a)
self.lokr_w1_b = torch.nn.Parameter(lokr_w1_b)
self.w1_rebuild = True
self.ranka = rank_a
if lokr_w2_a is not None:
_, rank_b = lokr_w2_a.shape[0], lokr_w2_a.shape[1]
rank_b, _ = lokr_w2_b.shape[0], lokr_w2_b.shape[1]
self.lokr_w2_a = torch.nn.Parameter(lokr_w2_a)
self.lokr_w2_b = torch.nn.Parameter(lokr_w2_b)
if lokr_t2 is not None:
self.use_tucker = True
self.lokr_t2 = torch.nn.Parameter(lokr_t2)
self.w2_rebuild = True
self.rankb = rank_b
if lokr_w1 is not None:
self.lokr_w1 = torch.nn.Parameter(lokr_w1)
self.w1_rebuild = False
if lokr_w2 is not None:
self.lokr_w2 = torch.nn.Parameter(lokr_w2)
self.w2_rebuild = False
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
@property
def w1(self):
if self.w1_rebuild:
return (self.lokr_w1_a @ self.lokr_w1_b) * (self.alpha / self.ranka)
else:
return self.lokr_w1
@property
def w2(self):
if self.w2_rebuild:
if self.use_tucker:
w2 = torch.einsum(
'i j k l, j r, i p -> p r k l',
self.lokr_t2,
self.lokr_w2_b,
self.lokr_w2_a
)
else:
w2 = self.lokr_w2_a @ self.lokr_w2_b
return w2 * (self.alpha / self.rankb)
else:
return self.lokr_w2
def __call__(self, w):
diff = torch.kron(self.w1, self.w2)
return w + diff.reshape(w.shape).to(w)
def passive_memory_usage(self):
return sum(param.numel() * param.element_size() for param in self.parameters())
class LoKrAdapter(WeightAdapterBase):
@ -13,6 +83,20 @@ class LoKrAdapter(WeightAdapterBase):
        self.loaded_keys = loaded_keys
        self.weights = weights
@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel()
out1, out2 = factorization(out_dim, rank)
in1, in2 = factorization(in_dim, rank)
mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype)
mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype)
torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
torch.nn.init.constant_(mat1, 0.0)
return LokrDiff(
(mat1, mat2, alpha, None, None, None, None, None, None)
)
    @classmethod
    def load(
        cls,

View File

@ -3,7 +3,58 @@ from typing import Optional
import torch
import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose
+from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization
class OFTDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
        # Unpack weights tuple from OFTAdapter
blocks, rescale, alpha, _ = weights
# Create trainable parameters
self.oft_blocks = torch.nn.Parameter(blocks)
if rescale is not None:
self.rescale = torch.nn.Parameter(rescale)
self.rescaled = True
else:
self.rescaled = False
self.block_num, self.block_size, _ = blocks.shape
self.constraint = float(alpha)
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
def __call__(self, w):
org_dtype = w.dtype
I = torch.eye(self.block_size, device=self.oft_blocks.device)
## generate r
# for Q = -Q^T
q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
normed_q = q
if self.constraint:
q_norm = torch.norm(q) + 1e-8
if q_norm > self.constraint:
normed_q = q * self.constraint / q_norm
# use float() to prevent unsupported type
r = (I + normed_q) @ (I - normed_q).float().inverse()
## Apply chunked matmul on weight
_, *shape = w.shape
org_weight = w.to(dtype=r.dtype)
org_weight = org_weight.unflatten(0, (self.block_num, self.block_size))
# Init R=0, so add I on it to ensure the output of step0 is original model output
weight = torch.einsum(
"k n m, k n ... -> k m ...",
r,
org_weight,
).flatten(0, 1)
if self.rescaled:
weight = self.rescale * weight
return weight.to(org_dtype)
def passive_memory_usage(self):
"""Calculates memory usage of the trainable parameters."""
return sum(param.numel() * param.element_size() for param in self.parameters())
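As I read OFTDiff.__call__, each block applies a Cayley-parametrized rotation to its slice of the weight:

$$ Q = B - B^{\top}\ \ (\text{so } Q^{\top} = -Q),\qquad R = (I + Q)(I - Q)^{-1},\qquad W' = \operatorname{blockdiag}(R_1,\dots,R_k)\,W, $$

optionally followed by the learned rescale. The Cayley transform of a skew-symmetric matrix is orthogonal, and since create_train starts the blocks at zero, Q = 0 and R = I at step 0, so the patched layer initially reproduces the original output (matching the comment in the code). The constraint, when non-zero, caps the norm of Q before the transform.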
class OFTAdapter(WeightAdapterBase):
@ -13,6 +64,18 @@ class OFTAdapter(WeightAdapterBase):
        self.loaded_keys = loaded_keys
        self.weights = weights
@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
block_size, block_num = factorization(out_dim, rank)
block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype)
return OFTDiff(
(block, None, alpha, None)
)
def to_train(self):
return OFTDiff(self.weights)
    @classmethod
    def load(
        cls,
@ -60,6 +123,8 @@
        blocks = v[0]
        rescale = v[1]
        alpha = v[2]
if alpha is None:
alpha = 0
        dora_scale = v[3]
        blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)

View File

@ -656,9 +656,34 @@ class Accumulation(ComfyTypeIO):
        accum: list[Any]
    Type = AccumulationDict


@comfytype(io_type="LOAD3D_CAMERA")
class Load3DCamera(ComfyTypeIO):
-   Type = Any  # TODO: figure out type for this; in code, only described as image['camera_info'], gotten from a LOAD_3D or LOAD_3D_ANIMATION type
+   class CameraInfo(TypedDict):
position: dict[str, float | int]
target: dict[str, float | int]
zoom: int
cameraType: str
Type = CameraInfo
@comfytype(io_type="LOAD_3D")
class Load3D(ComfyTypeIO):
"""3D models are stored as a dictionary."""
class Model3DDict(TypedDict):
image: str
mask: str
normal: str
camera_info: Load3DCamera.CameraInfo
recording: NotRequired[str]
Type = Model3DDict
@comfytype(io_type="LOAD_3D_ANIMATION")
class Load3DAnimation(Load3D):
...
@comfytype(io_type="PHOTOMAKER") @comfytype(io_type="PHOTOMAKER")

View File

@ -475,11 +475,12 @@ class PreviewVideo(_UIOutput):
class PreviewUI3D(_UIOutput):
-   def __init__(self, values: list[SavedResult | dict], **kwargs):
-       self.values = values
+   def __init__(self, model_file, camera_info, **kwargs):
+       self.model_file = model_file
+       self.camera_info = camera_info

    def as_dict(self):
-       return {"3d": self.values}
+       return {"result": [self.model_file, self.camera_info]}


class PreviewText(_UIOutput):

View File

@ -20,7 +20,7 @@ import folder_paths
import node_helpers
from comfy.cli_args import args
from comfy.comfy_types.node_typing import IO
-from comfy.weight_adapter import adapters
+from comfy.weight_adapter import adapters, adapter_maps


def make_batch_extra_option_dict(d, indicies, full_size=None):
@ -39,13 +39,13 @@ def make_batch_extra_option_dict(d, indicies, full_size=None):
class TrainSampler(comfy.samplers.Sampler):
-   def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
+   def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.loss_callback = loss_callback
        self.batch_size = batch_size
        self.total_steps = total_steps
+       self.grad_acc = grad_acc
        self.seed = seed
        self.training_dtype = training_dtype
@ -92,8 +92,9 @@ class TrainSampler(comfy.samplers.Sampler):
                self.loss_callback(loss.item())
                pbar.set_postfix({"loss": f"{loss.item():.4f}"})

-           self.optimizer.step()
-           self.optimizer.zero_grad()
+           if (i+1) % self.grad_acc == 0:
+               self.optimizer.step()
+               self.optimizer.zero_grad()
        torch.cuda.empty_cache()
        return torch.zeros_like(latent_image)
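A generic sketch of the accumulation pattern the new `grad_acc` branch follows; the linear model and data here are stand-ins, not ComfyUI objects:

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.SmoothL1Loss()
grad_acc = 4
total_steps = 8   # the node similarly passes steps * grad_accumulation_steps

for i in range(total_steps):
    x, target = torch.randn(2, 4), torch.randn(2, 1)
    loss = loss_fn(model(x), target)
    loss.backward()                   # gradients keep accumulating in .grad
    if (i + 1) % grad_acc == 0:       # update weights only every grad_acc iterations
        optimizer.step()
        optimizer.zero_grad()
```

Note that the diff does not rescale the loss by grad_acc, so gradients are summed rather than averaged across the accumulated micro-batches.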
@ -419,6 +420,16 @@ class TrainLoraNode:
"tooltip": "The batch size to use for training.", "tooltip": "The batch size to use for training.",
}, },
), ),
"grad_accumulation_steps": (
IO.INT,
{
"default": 1,
"min": 1,
"max": 1024,
"step": 1,
"tooltip": "The number of gradient accumulation steps to use for training.",
}
),
"steps": ( "steps": (
IO.INT, IO.INT,
{ {
@ -478,6 +489,17 @@ class TrainLoraNode:
["bf16", "fp32"], ["bf16", "fp32"],
{"default": "bf16", "tooltip": "The dtype to use for lora."}, {"default": "bf16", "tooltip": "The dtype to use for lora."},
), ),
"algorithm": (
list(adapter_maps.keys()),
{"default": list(adapter_maps.keys())[0], "tooltip": "The algorithm to use for training."},
),
"gradient_checkpointing": (
IO.BOOLEAN,
{
"default": True,
"tooltip": "Use gradient checkpointing for training.",
}
),
"existing_lora": ( "existing_lora": (
folder_paths.get_filename_list("loras") + ["[None]"], folder_paths.get_filename_list("loras") + ["[None]"],
{ {
@ -501,6 +523,7 @@ class TrainLoraNode:
        positive,
        batch_size,
        steps,
        grad_accumulation_steps,
        learning_rate,
        rank,
        optimizer,
@ -508,6 +531,8 @@
        seed,
        training_dtype,
        lora_dtype,
        algorithm,
        gradient_checkpointing,
        existing_lora,
    ):
        mp = model.clone()
@ -558,10 +583,8 @@ class TrainLoraNode:
            if existing_adapter is not None:
                break
        else:
-           # If no existing adapter found, use LoRA
-           # We will add algo option in the future
            existing_adapter = None
-           adapter_cls = adapters[0]
+           adapter_cls = adapter_maps[algorithm]

        if existing_adapter is not None:
            train_adapter = existing_adapter.to_train().to(lora_dtype)
@ -615,8 +638,9 @@ class TrainLoraNode:
        criterion = torch.nn.SmoothL1Loss()

        # setup models
-       for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
-           patch(m)
+       if gradient_checkpointing:
+           for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
+               patch(m)
        mp.model.requires_grad_(False)
        comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
@ -629,7 +653,8 @@ class TrainLoraNode:
            optimizer,
            loss_callback=loss_callback,
            batch_size=batch_size,
-           total_steps=steps,
+           grad_acc=grad_accumulation_steps,
+           total_steps=steps*grad_accumulation_steps,
            seed=seed,
            training_dtype=dtype
        )

View File

@ -19,14 +19,14 @@ class ConditioningStableAudio(io.ComfyNode):
node_id="ConditioningStableAudio_V3", node_id="ConditioningStableAudio_V3",
category="conditioning", category="conditioning",
inputs=[ inputs=[
io.Conditioning.Input(id="positive"), io.Conditioning.Input("positive"),
io.Conditioning.Input(id="negative"), io.Conditioning.Input("negative"),
io.Float.Input(id="seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1), io.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1),
io.Float.Input(id="seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1), io.Float.Input("seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1),
], ],
outputs=[ outputs=[
io.Conditioning.Output(id="positive_out", display_name="positive"), io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(id="negative_out", display_name="negative"), io.Conditioning.Output(display_name="negative"),
], ],
) )
@ -49,7 +49,7 @@ class EmptyLatentAudio(io.ComfyNode):
node_id="EmptyLatentAudio_V3", node_id="EmptyLatentAudio_V3",
category="latent/audio", category="latent/audio",
inputs=[ inputs=[
io.Float.Input(id="seconds", default=47.6, min=1.0, max=1000.0, step=0.1), io.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
io.Int.Input( io.Int.Input(
id="batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." id="batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
), ),
@ -200,8 +200,8 @@ class VAEDecodeAudio(io.ComfyNode):
node_id="VAEDecodeAudio_V3", node_id="VAEDecodeAudio_V3",
category="latent/audio", category="latent/audio",
inputs=[ inputs=[
io.Latent.Input(id="samples"), io.Latent.Input("samples"),
io.Vae.Input(id="vae"), io.Vae.Input("vae"),
], ],
outputs=[io.Audio.Output()], outputs=[io.Audio.Output()],
) )
@ -222,8 +222,8 @@ class VAEEncodeAudio(io.ComfyNode):
node_id="VAEEncodeAudio_V3", node_id="VAEEncodeAudio_V3",
category="latent/audio", category="latent/audio",
inputs=[ inputs=[
io.Audio.Input(id="audio"), io.Audio.Input("audio"),
io.Vae.Input(id="vae"), io.Vae.Input("vae"),
], ],
outputs=[io.Latent.Output()], outputs=[io.Latent.Output()],
) )

View File

@ -13,7 +13,7 @@ class DifferentialDiffusion(io.ComfyNode):
display_name="Differential Diffusion _V3", display_name="Differential Diffusion _V3",
category="_for_testing", category="_for_testing",
inputs=[ inputs=[
io.Model.Input(id="model"), io.Model.Input("model"),
], ],
outputs=[ outputs=[
io.Model.Output(), io.Model.Output(),

View File

@ -32,10 +32,10 @@ class CLIPTextEncodeFlux(io.ComfyNode):
node_id="CLIPTextEncodeFlux_V3", node_id="CLIPTextEncodeFlux_V3",
category="advanced/conditioning/flux", category="advanced/conditioning/flux",
inputs=[ inputs=[
io.Clip.Input(id="clip"), io.Clip.Input("clip"),
io.String.Input(id="clip_l", multiline=True, dynamic_prompts=True), io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
io.String.Input(id="t5xxl", multiline=True, dynamic_prompts=True), io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
io.Float.Input(id="guidance", default=3.5, min=0.0, max=100.0, step=0.1), io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
], ],
outputs=[ outputs=[
io.Conditioning.Output(), io.Conditioning.Output(),
@ -58,7 +58,7 @@ class FluxDisableGuidance(io.ComfyNode):
category="advanced/conditioning/flux", category="advanced/conditioning/flux",
description="This node completely disables the guidance embed on Flux and Flux like models", description="This node completely disables the guidance embed on Flux and Flux like models",
inputs=[ inputs=[
io.Conditioning.Input(id="conditioning"), io.Conditioning.Input("conditioning"),
], ],
outputs=[ outputs=[
io.Conditioning.Output(), io.Conditioning.Output(),
@ -78,8 +78,8 @@ class FluxGuidance(io.ComfyNode):
node_id="FluxGuidance_V3", node_id="FluxGuidance_V3",
category="advanced/conditioning/flux", category="advanced/conditioning/flux",
inputs=[ inputs=[
io.Conditioning.Input(id="conditioning"), io.Conditioning.Input("conditioning"),
io.Float.Input(id="guidance", default=3.5, min=0.0, max=100.0, step=0.1), io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
], ],
outputs=[ outputs=[
io.Conditioning.Output(), io.Conditioning.Output(),
@ -100,7 +100,7 @@ class FluxKontextImageScale(io.ComfyNode):
category="advanced/conditioning/flux", category="advanced/conditioning/flux",
description="This node resizes the image to one that is more optimal for flux kontext.", description="This node resizes the image to one that is more optimal for flux kontext.",
inputs=[ inputs=[
io.Image.Input(id="image"), io.Image.Input("image"),
], ],
outputs=[ outputs=[
io.Image.Output(), io.Image.Output(),

View File

@ -35,11 +35,11 @@ class FreeU(io.ComfyNode):
node_id="FreeU_V3", node_id="FreeU_V3",
category="model_patches/unet", category="model_patches/unet",
inputs=[ inputs=[
io.Model.Input(id="model"), io.Model.Input("model"),
io.Float.Input(id="b1", default=1.1, min=0.0, max=10.0, step=0.01), io.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01),
io.Float.Input(id="b2", default=1.2, min=0.0, max=10.0, step=0.01), io.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01),
io.Float.Input(id="s1", default=0.9, min=0.0, max=10.0, step=0.01), io.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
io.Float.Input(id="s2", default=0.2, min=0.0, max=10.0, step=0.01), io.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
], ],
outputs=[ outputs=[
io.Model.Output(), io.Model.Output(),
@ -80,11 +80,11 @@ class FreeU_V2(io.ComfyNode):
node_id="FreeU_V2_V3", node_id="FreeU_V2_V3",
category="model_patches/unet", category="model_patches/unet",
inputs=[ inputs=[
io.Model.Input(id="model"), io.Model.Input("model"),
io.Float.Input(id="b1", default=1.3, min=0.0, max=10.0, step=0.01), io.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01),
io.Float.Input(id="b2", default=1.4, min=0.0, max=10.0, step=0.01), io.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01),
io.Float.Input(id="s1", default=0.9, min=0.0, max=10.0, step=0.01), io.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
io.Float.Input(id="s2", default=0.2, min=0.0, max=10.0, step=0.01), io.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
], ],
outputs=[ outputs=[
io.Model.Output(), io.Model.Output(),

View File

@ -65,12 +65,12 @@ class FreSca(io.ComfyNode):
category="_for_testing", category="_for_testing",
description="Applies frequency-dependent scaling to the guidance", description="Applies frequency-dependent scaling to the guidance",
inputs=[ inputs=[
io.Model.Input(id="model"), io.Model.Input("model"),
io.Float.Input(id="scale_low", default=1.0, min=0, max=10, step=0.01, io.Float.Input("scale_low", default=1.0, min=0, max=10, step=0.01,
tooltip="Scaling factor for low-frequency components"), tooltip="Scaling factor for low-frequency components"),
io.Float.Input(id="scale_high", default=1.25, min=0, max=10, step=0.01, io.Float.Input("scale_high", default=1.25, min=0, max=10, step=0.01,
tooltip="Scaling factor for high-frequency components"), tooltip="Scaling factor for high-frequency components"),
io.Int.Input(id="freq_cutoff", default=20, min=1, max=10000, step=1, io.Int.Input("freq_cutoff", default=20, min=1, max=10000, step=1,
tooltip="Number of frequency indices around center to consider as low-frequency"), tooltip="Number of frequency indices around center to consider as low-frequency"),
], ],
outputs=[ outputs=[

View File

@ -343,9 +343,9 @@ class GITSScheduler(io.ComfyNode):
node_id="GITSScheduler_V3", node_id="GITSScheduler_V3",
category="sampling/custom_sampling/schedulers", category="sampling/custom_sampling/schedulers",
inputs=[ inputs=[
io.Float.Input(id="coeff", default=1.20, min=0.80, max=1.50, step=0.05), io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05),
io.Int.Input(id="steps", default=10, min=2, max=1000), io.Int.Input("steps", default=10, min=2, max=1000),
io.Float.Input(id="denoise", default=1.0, min=0.0, max=1.0, step=0.01), io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
], ],
outputs=[ outputs=[
io.Sigmas.Output(), io.Sigmas.Output(),

View File

@ -0,0 +1,167 @@
from __future__ import annotations
import torch
import comfy.model_management
import node_helpers
import nodes
from comfy_api.v3 import io
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
"1. The main content and theme of the video."
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
"4. background environment, light, style and atmosphere."
"5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeHunyuanDiT_V3",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.String.Input("bert", multiline=True, dynamic_prompts=True),
io.String.Input("mt5xl", multiline=True, dynamic_prompts=True),
],
outputs=[
io.Conditioning.Output(),
],
)
@classmethod
def execute(cls, clip, bert, mt5xl):
tokens = clip.tokenize(bert)
tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
class EmptyHunyuanLatentVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EmptyHunyuanLatentVideo_V3",
category="latent/video",
inputs=[
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=25, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod
def execute(cls, width, height, length, batch_size):
latent = torch.zeros(
[batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
device=comfy.model_management.intermediate_device(),
)
return io.NodeOutput({"samples":latent})
class HunyuanImageToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="HunyuanImageToVideo_V3",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Vae.Input("vae"),
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=53, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Combo.Input("guidance_type", options=["v1 (concat)", "v2 (replace)", "custom"]),
io.Image.Input("start_image", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
latent = torch.zeros(
[batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
device=comfy.model_management.intermediate_device(),
)
out_latent = {}
if start_image is not None:
start_image = comfy.utils.common_upscale(
start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center"
).movedim(1, -1)
concat_latent_image = vae.encode(start_image)
mask = torch.ones(
(1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]),
device=start_image.device,
dtype=start_image.dtype,
)
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
if guidance_type == "v1 (concat)":
cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
elif guidance_type == "v2 (replace)":
cond = {'guiding_frame_index': 0}
latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
out_latent["noise_mask"] = mask
elif guidance_type == "custom":
cond = {"ref_latent": concat_latent_image}
positive = node_helpers.conditioning_set_values(positive, cond)
out_latent["samples"] = latent
return io.NodeOutput(positive, out_latent)
class TextEncodeHunyuanVideo_ImageToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TextEncodeHunyuanVideo_ImageToVideo_V3",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.ClipVisionOutput.Input("clip_vision_output"),
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
io.Int.Input(
"image_interleave",
default=2,
min=1,
max=512,
tooltip="How much the image influences things vs the text prompt. Higher number means more influence from the text prompt.",
),
],
outputs=[
io.Conditioning.Output(),
],
)
@classmethod
def execute(cls, clip, clip_vision_output, prompt, image_interleave):
tokens = clip.tokenize(
prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V,
image_embeds=clip_vision_output.mm_projected,
image_interleave=image_interleave,
)
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
NODES_LIST = [
CLIPTextEncodeHunyuanDiT,
EmptyHunyuanLatentVideo,
HunyuanImageToVideo,
TextEncodeHunyuanVideo_ImageToVideo,
]
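As a quick sanity check of the shape arithmetic shared by EmptyHunyuanLatentVideo and HunyuanImageToVideo (default 848x480, length 53):

```python
batch_size, length, height, width = 1, 53, 480, 848
shape = [batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8]
print(shape)  # [1, 16, 14, 60, 106]: 16 latent channels, ~4x temporal and 8x spatial downsampling
```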

View File

@ -0,0 +1,136 @@
from __future__ import annotations
import logging
import torch
import comfy.utils
import folder_paths
from comfy_api.v3 import io
def load_hypernetwork_patch(path, strength):
sd = comfy.utils.load_torch_file(path, safe_load=True)
activation_func = sd.get('activation_func', 'linear')
is_layer_norm = sd.get('is_layer_norm', False)
use_dropout = sd.get('use_dropout', False)
activate_output = sd.get('activate_output', False)
last_layer_dropout = sd.get('last_layer_dropout', False)
valid_activation = {
"linear": torch.nn.Identity,
"relu": torch.nn.ReLU,
"leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU,
"swish": torch.nn.Hardswish,
"tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid,
"softsign": torch.nn.Softsign,
"mish": torch.nn.Mish,
}
    if activation_func not in valid_activation:
        logging.error(
            "Unsupported Hypernetwork format, if you report it I might implement it. {} {} {} {} {} {}".format(
                path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout
            )
        )
        return None
out = {}
for d in sd:
try:
dim = int(d)
except Exception:
continue
output = []
for index in [0, 1]:
attn_weights = sd[dim][index]
keys = attn_weights.keys()
linears = filter(lambda a: a.endswith(".weight"), keys)
linears = list(map(lambda a: a[:-len(".weight")], linears))
layers = []
i = 0
while i < len(linears):
lin_name = linears[i]
last_layer = (i == (len(linears) - 1))
penultimate_layer = (i == (len(linears) - 2))
lin_weight = attn_weights['{}.weight'.format(lin_name)]
lin_bias = attn_weights['{}.bias'.format(lin_name)]
layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
layers.append(layer)
if activation_func != "linear":
if (not last_layer) or (activate_output):
layers.append(valid_activation[activation_func]())
if is_layer_norm:
i += 1
ln_name = linears[i]
ln_weight = attn_weights['{}.weight'.format(ln_name)]
ln_bias = attn_weights['{}.bias'.format(ln_name)]
ln = torch.nn.LayerNorm(ln_weight.shape[0])
ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
layers.append(ln)
if use_dropout:
if (not last_layer) and (not penultimate_layer or last_layer_dropout):
layers.append(torch.nn.Dropout(p=0.3))
i += 1
output.append(torch.nn.Sequential(*layers))
out[dim] = torch.nn.ModuleList(output)
class hypernetwork_patch:
def __init__(self, hypernet, strength):
self.hypernet = hypernet
self.strength = strength
def __call__(self, q, k, v, extra_options):
dim = k.shape[-1]
if dim in self.hypernet:
hn = self.hypernet[dim]
k = k + hn[0](k) * self.strength
v = v + hn[1](v) * self.strength
return q, k, v
def to(self, device):
for d in self.hypernet.keys():
self.hypernet[d] = self.hypernet[d].to(device)
return self
return hypernetwork_patch(out, strength)
class HypernetworkLoader(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="HypernetworkLoader_V3",
category="loaders",
inputs=[
io.Model.Input("model"),
io.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")),
io.Float.Input("strength", default=1.0, min=-10.0, max=10.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, hypernetwork_name, strength):
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
model_hypernetwork = model.clone()
patch = load_hypernetwork_patch(hypernetwork_path, strength)
if patch is not None:
model_hypernetwork.set_model_attn1_patch(patch)
model_hypernetwork.set_model_attn2_patch(patch)
return io.NodeOutput(model_hypernetwork)
NODES_LIST = [
HypernetworkLoader,
]
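A toy illustration of what hypernetwork_patch.__call__ does to attention inputs: each of k and v gets an additive, strength-scaled correction from an MLP selected by its channel width, while q passes through untouched. The two-layer MLPs below are stand-ins for the loaded per-dimension modules:

```python
import torch

dim, strength = 8, 0.5
hn_k = torch.nn.Sequential(torch.nn.Linear(dim, dim), torch.nn.ReLU(), torch.nn.Linear(dim, dim))
hn_v = torch.nn.Sequential(torch.nn.Linear(dim, dim), torch.nn.ReLU(), torch.nn.Linear(dim, dim))

q = torch.randn(1, 16, dim)
k = torch.randn(1, 16, dim)
v = torch.randn(1, 16, dim)

k = k + hn_k(k) * strength   # mirrors: k + hn[0](k) * self.strength
v = v + hn_v(v) * strength   # mirrors: v + hn[1](v) * self.strength
# q is returned unchanged, exactly as in hypernetwork_patch.__call__
```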

View File

@ -0,0 +1,95 @@
"""Taken from: https://github.com/tfernd/HyperTile/"""
from __future__ import annotations
import math
from einops import rearrange
from torch import randint
from comfy_api.v3 import io
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
min_value = min(min_value, value)
# All big divisors of value (inclusive)
divisors = [i for i in range(min_value, value + 1) if value % i == 0]
ns = [value // i for i in divisors[:max_options]] # has at least 1 element
if len(ns) - 1 > 0:
idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()
else:
idx = 0
return ns[idx]
class HyperTile(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="HyperTile_V3",
category="model_patches/unet",
inputs=[
io.Model.Input("model"),
io.Int.Input("tile_size", default=256, min=1, max=2048),
io.Int.Input("swap_size", default=2, min=1, max=128),
io.Int.Input("max_depth", default=0, min=0, max=10),
io.Boolean.Input("scale_depth", default=False),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, tile_size, swap_size, max_depth, scale_depth):
latent_tile_size = max(32, tile_size) // 8
temp = None
def hypertile_in(q, k, v, extra_options):
nonlocal temp
model_chans = q.shape[-2]
orig_shape = extra_options['original_shape']
apply_to = []
for i in range(max_depth + 1):
apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i)))
if model_chans in apply_to:
shape = extra_options["original_shape"]
aspect_ratio = shape[-1] / shape[-2]
hw = q.size(1)
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1
nh = random_divisor(h, latent_tile_size * factor, swap_size)
nw = random_divisor(w, latent_tile_size * factor, swap_size)
if nh * nw > 1:
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
temp = (nh, nw, h, w)
return q, k, v
return q, k, v
def hypertile_out(out, extra_options):
nonlocal temp
if temp is not None:
nh, nw, h, w = temp
temp = None
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
return out
m = model.clone()
m.set_model_attn1_patch(hypertile_in)
m.set_model_attn1_output_patch(hypertile_out)
return io.NodeOutput(m)
NODES_LIST = [
HyperTile,
]
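A round-trip check of the einops patterns used by hypertile_in/hypertile_out, with toy sizes (b=1, h=w=4, nh=nw=2) chosen here for illustration:

```python
import torch
from einops import rearrange

b, h, w, c = 1, 4, 4, 8
nh, nw = 2, 2
q = torch.randn(b, h * w, c)

tiled = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
assert tiled.shape == (b * nh * nw, (h // nh) * (w // nw), c)

out = rearrange(tiled, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
assert torch.equal(out, q)   # the output patch exactly inverts the input patch
```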

View File

@ -0,0 +1,56 @@
from __future__ import annotations
import torch
from comfy_api.v3 import io
class InstructPixToPixConditioning(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="InstructPixToPixConditioning_V3",
category="conditioning/instructpix2pix",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Image.Input("pixels"),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, positive, negative, pixels, vae):
x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8
if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % 8) // 2
y_offset = (pixels.shape[2] % 8) // 2
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
concat_latent = vae.encode(pixels)
out_latent = {}
out_latent["samples"] = torch.zeros_like(concat_latent)
out = []
for conditioning in [positive, negative]:
c = []
for t in conditioning:
d = t[1].copy()
d["concat_latent_image"] = concat_latent
n = [t[0], d]
c.append(n)
out.append(c)
return io.NodeOutput(out[0], out[1], out_latent)
NODES_LIST = [
InstructPixToPixConditioning,
]

View File

@ -24,8 +24,8 @@ class LatentAdd(io.ComfyNode):
node_id="LatentAdd_V3", node_id="LatentAdd_V3",
category="latent/advanced", category="latent/advanced",
inputs=[ inputs=[
io.Latent.Input(id="samples1"), io.Latent.Input("samples1"),
io.Latent.Input(id="samples2"), io.Latent.Input("samples2"),
], ],
outputs=[ outputs=[
io.Latent.Output(), io.Latent.Output(),
@ -52,8 +52,8 @@ class LatentApplyOperation(io.ComfyNode):
category="latent/advanced/operations", category="latent/advanced/operations",
is_experimental=True, is_experimental=True,
inputs=[ inputs=[
io.Latent.Input(id="samples"), io.Latent.Input("samples"),
io.LatentOperation.Input(id="operation"), io.LatentOperation.Input("operation"),
], ],
outputs=[ outputs=[
io.Latent.Output(), io.Latent.Output(),
@ -77,8 +77,8 @@ class LatentApplyOperationCFG(io.ComfyNode):
category="latent/advanced/operations", category="latent/advanced/operations",
is_experimental=True, is_experimental=True,
inputs=[ inputs=[
io.Model.Input(id="model"), io.Model.Input("model"),
io.LatentOperation.Input(id="operation"), io.LatentOperation.Input("operation"),
], ],
outputs=[ outputs=[
io.Model.Output(), io.Model.Output(),
@ -108,8 +108,8 @@ class LatentBatch(io.ComfyNode):
node_id="LatentBatch_V3", node_id="LatentBatch_V3",
category="latent/batch", category="latent/batch",
inputs=[ inputs=[
io.Latent.Input(id="samples1"), io.Latent.Input("samples1"),
io.Latent.Input(id="samples2"), io.Latent.Input("samples2"),
], ],
outputs=[ outputs=[
io.Latent.Output(), io.Latent.Output(),
@ -137,8 +137,8 @@ class LatentBatchSeedBehavior(io.ComfyNode):
node_id="LatentBatchSeedBehavior_V3", node_id="LatentBatchSeedBehavior_V3",
category="latent/advanced", category="latent/advanced",
inputs=[ inputs=[
io.Latent.Input(id="samples"), io.Latent.Input("samples"),
io.Combo.Input(id="seed_behavior", options=["random", "fixed"], default="fixed"), io.Combo.Input("seed_behavior", options=["random", "fixed"], default="fixed"),
], ],
outputs=[ outputs=[
io.Latent.Output(), io.Latent.Output(),
@ -166,9 +166,9 @@ class LatentInterpolate(io.ComfyNode):
node_id="LatentInterpolate_V3", node_id="LatentInterpolate_V3",
category="latent/advanced", category="latent/advanced",
inputs=[ inputs=[
io.Latent.Input(id="samples1"), io.Latent.Input("samples1"),
io.Latent.Input(id="samples2"), io.Latent.Input("samples2"),
io.Float.Input(id="ratio", default=1.0, min=0.0, max=1.0, step=0.01), io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
], ],
outputs=[ outputs=[
io.Latent.Output(), io.Latent.Output(),
@ -205,8 +205,8 @@ class LatentMultiply(io.ComfyNode):
node_id="LatentMultiply_V3", node_id="LatentMultiply_V3",
category="latent/advanced", category="latent/advanced",
inputs=[ inputs=[
io.Latent.Input(id="samples"), io.Latent.Input("samples"),
io.Float.Input(id="multiplier", default=1.0, min=-10.0, max=10.0, step=0.01), io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
], ],
outputs=[ outputs=[
io.Latent.Output(), io.Latent.Output(),
@ -230,9 +230,9 @@ class LatentOperationSharpen(io.ComfyNode):
category="latent/advanced/operations", category="latent/advanced/operations",
is_experimental=True, is_experimental=True,
inputs=[ inputs=[
io.Int.Input(id="sharpen_radius", default=9, min=1, max=31, step=1), io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1),
io.Float.Input(id="sigma", default=1.0, min=0.1, max=10.0, step=0.1), io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1),
io.Float.Input(id="alpha", default=0.1, min=0.0, max=5.0, step=0.01), io.Float.Input("alpha", default=0.1, min=0.0, max=5.0, step=0.01),
], ],
outputs=[ outputs=[
io.LatentOperation.Output(), io.LatentOperation.Output(),
@ -272,7 +272,7 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
category="latent/advanced/operations", category="latent/advanced/operations",
is_experimental=True, is_experimental=True,
inputs=[ inputs=[
io.Float.Input(id="multiplier", default=1.0, min=0.0, max=100.0, step=0.01), io.Float.Input("multiplier", default=1.0, min=0.0, max=100.0, step=0.01),
], ],
outputs=[ outputs=[
io.LatentOperation.Output(), io.LatentOperation.Output(),
@ -306,8 +306,8 @@ class LatentSubtract(io.ComfyNode):
node_id="LatentSubtract_V3", node_id="LatentSubtract_V3",
category="latent/advanced", category="latent/advanced",
inputs=[ inputs=[
io.Latent.Input(id="samples1"), io.Latent.Input("samples1"),
io.Latent.Input(id="samples2"), io.Latent.Input("samples2"),
], ],
outputs=[ outputs=[
io.Latent.Output(), io.Latent.Output(),

View File

@ -0,0 +1,180 @@
from __future__ import annotations
import os
from pathlib import Path
import folder_paths
import nodes
from comfy_api.input_impl import VideoFromFile
from comfy_api.v3 import io, ui
def normalize_path(path):
return path.replace("\\", "/")
class Load3D(io.ComfyNode):
@classmethod
def define_schema(cls):
input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
os.makedirs(input_dir, exist_ok=True)
input_path = Path(input_dir)
base_path = Path(folder_paths.get_input_directory())
files = [
normalize_path(str(file_path.relative_to(base_path)))
for file_path in input_path.rglob("*")
if file_path.suffix.lower() in {".gltf", ".glb", ".obj", ".fbx", ".stl"}
]
return io.Schema(
node_id="Load3D_V3",
display_name="Load 3D _V3",
category="3d",
is_experimental=True,
inputs=[
io.Combo.Input("model_file", options=sorted(files), upload=io.UploadType.model),
io.Load3D.Input("image"),
io.Int.Input("width", default=1024, min=1, max=4096, step=1),
io.Int.Input("height", default=1024, min=1, max=4096, step=1),
],
outputs=[
io.Image.Output(display_name="image"),
io.Mask.Output(display_name="mask"),
io.String.Output(display_name="mesh_path"),
io.Image.Output(display_name="normal"),
io.Image.Output(display_name="lineart"),
io.Load3DCamera.Output(display_name="camera_info"),
io.Video.Output(display_name="recording_video"),
],
)
@classmethod
def execute(cls, model_file, image, **kwargs):
image_path = folder_paths.get_annotated_filepath(image["image"])
mask_path = folder_paths.get_annotated_filepath(image["mask"])
normal_path = folder_paths.get_annotated_filepath(image["normal"])
lineart_path = folder_paths.get_annotated_filepath(image["lineart"])
load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
video = None
if image["recording"] != "":
recording_video_path = folder_paths.get_annotated_filepath(image["recording"])
video = VideoFromFile(recording_video_path)
return io.NodeOutput(
output_image, output_mask, model_file, normal_image, lineart_image, image["camera_info"], video
)
class Load3DAnimation(io.ComfyNode):
@classmethod
def define_schema(cls):
input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
os.makedirs(input_dir, exist_ok=True)
input_path = Path(input_dir)
base_path = Path(folder_paths.get_input_directory())
files = [
normalize_path(str(file_path.relative_to(base_path)))
for file_path in input_path.rglob("*")
if file_path.suffix.lower() in {".gltf", ".glb", ".fbx"}
]
return io.Schema(
node_id="Load3DAnimation_V3",
display_name="Load 3D - Animation _V3",
category="3d",
is_experimental=True,
inputs=[
io.Combo.Input("model_file", options=sorted(files), upload=io.UploadType.model),
io.Load3DAnimation.Input("image"),
io.Int.Input("width", default=1024, min=1, max=4096, step=1),
io.Int.Input("height", default=1024, min=1, max=4096, step=1),
],
outputs=[
io.Image.Output(display_name="image"),
io.Mask.Output(display_name="mask"),
io.String.Output(display_name="mesh_path"),
io.Image.Output(display_name="normal"),
io.Load3DCamera.Output(display_name="camera_info"),
io.Video.Output(display_name="recording_video"),
],
)
@classmethod
def execute(cls, model_file, image, **kwargs):
image_path = folder_paths.get_annotated_filepath(image["image"])
mask_path = folder_paths.get_annotated_filepath(image["mask"])
normal_path = folder_paths.get_annotated_filepath(image["normal"])
load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
video = None
if image['recording'] != "":
recording_video_path = folder_paths.get_annotated_filepath(image["recording"])
video = VideoFromFile(recording_video_path)
return io.NodeOutput(output_image, output_mask, model_file, normal_image, image["camera_info"], video)
class Preview3D(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Preview3D_V3", # frontend expects "Preview3D" to work
display_name="Preview 3D _V3",
category="3d",
is_experimental=True,
is_output_node=True,
inputs=[
io.String.Input("model_file", default="", multiline=False),
io.Load3DCamera.Input("camera_info", optional=True),
],
outputs=[],
)
@classmethod
def execute(cls, model_file, camera_info=None):
return io.NodeOutput(ui=ui.PreviewUI3D(model_file, camera_info, cls=cls))
class Preview3DAnimation(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Preview3DAnimation_V3", # frontend expects "Preview3DAnimation" to work
display_name="Preview 3D - Animation _V3",
category="3d",
is_experimental=True,
is_output_node=True,
inputs=[
io.String.Input("model_file", default="", multiline=False),
io.Load3DCamera.Input("camera_info", optional=True),
],
outputs=[],
)
@classmethod
def execute(cls, model_file, camera_info=None):
return io.NodeOutput(ui=ui.PreviewUI3D(model_file, camera_info, cls=cls))
NODES_LIST = [
Load3D,
Load3DAnimation,
Preview3D,
Preview3DAnimation,
]

View File

@@ -0,0 +1,138 @@
from __future__ import annotations
import logging
import os
from enum import Enum
import torch
import comfy.model_management
import comfy.utils
import folder_paths
from comfy_api.v3 import io
CLAMP_QUANTILE = 0.99
def extract_lora(diff, rank):
conv2d = (len(diff.shape) == 4)
kernel_size = None if not conv2d else diff.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
out_dim, in_dim = diff.size()[0:2]
rank = min(rank, in_dim, out_dim)
if conv2d:
if conv2d_3x3:
diff = diff.flatten(start_dim=1)
else:
diff = diff.squeeze()
U, S, Vh = torch.linalg.svd(diff.float())
U = U[:, :rank]
S = S[:rank]
U = U @ torch.diag(S)
Vh = Vh[:rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
if conv2d:
U = U.reshape(out_dim, rank, 1, 1)
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
return (U, Vh)
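# The SVD above yields a rank-limited factorization of the weight difference:
# diff is approximated by (U @ diag(S[:rank])) @ Vh[:rank], later stored as the
# lora_up / lora_down pair, and both factors are clamped at the 0.99 quantile
# so a few extreme entries cannot dominate the reconstruction. A rough sketch
# of the round trip, assuming a plain 2-D weight difference:
#
#   diff = torch.randn(64, 32)               # stand-in weight difference
#   up, down = extract_lora(diff, rank=8)    # shapes (64, 8) and (8, 32)
#   approx = up @ down                       # low-rank approximation of diff
#   rel_err = (diff - approx).norm() / diff.norm()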
class LORAType(Enum):
STANDARD = 0
FULL_DIFF = 1
LORA_TYPES = {
"standard": LORAType.STANDARD,
"full_diff": LORAType.FULL_DIFF,
}
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
for k in sd:
if k.endswith(".weight"):
weight_diff = sd[k]
if lora_type == LORAType.STANDARD:
if weight_diff.ndim < 2:
if bias_diff:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
continue
try:
out = extract_lora(weight_diff, rank)
output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().half().cpu()
output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().half().cpu()
except Exception:
logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))
elif lora_type == LORAType.FULL_DIFF:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
elif bias_diff and k.endswith(".bias"):
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
return output_sd
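# Key layout of the resulting state dict: "standard" stores the low-rank
# factors from extract_lora() as <key>.lora_up.weight / <key>.lora_down.weight,
# "full_diff" stores the raw weight difference as <key>.diff, and bias
# differences are written as <key>.diff_b when bias_diff is enabled. All
# tensors are cast to fp16 and moved to CPU before saving.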
class LoraSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoraSave_V3",
display_name="Extract and Save Lora _V3",
category="_for_testing",
is_output_node=True,
inputs=[
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
io.Int.Input("rank", default=8, min=1, max=4096, step=1),
io.Combo.Input("lora_type", options=list(LORA_TYPES.keys())),
io.Boolean.Input("bias_diff", default=True),
io.Model.Input(
id="model_diff", optional=True, tooltip="The ModelSubtract output to be converted to a lora."
),
io.Clip.Input(
id="text_encoder_diff", optional=True, tooltip="The CLIPSubtract output to be converted to a lora."
),
],
outputs=[],
is_experimental=True,
)
@classmethod
def execute(cls, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
if model_diff is None and text_encoder_diff is None:
return io.NodeOutput()
lora_type = LORA_TYPES.get(lora_type)
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix, folder_paths.get_output_directory()
)
output_sd = {}
if model_diff is not None:
output_sd = calc_lora_model(
model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, bias_diff=bias_diff
)
if text_encoder_diff is not None:
output_sd = calc_lora_model(
text_encoder_diff.patcher, rank, "", "text_encoders.", output_sd, lora_type, bias_diff=bias_diff
)
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
return io.NodeOutput()
NODES_LIST = [
LoraSave,
]

File diff suppressed because one or more lines are too long

View File

@@ -93,10 +93,10 @@ class EmptyLTXVLatentVideo(io.ComfyNode):
node_id="EmptyLTXVLatentVideo_V3", node_id="EmptyLTXVLatentVideo_V3",
category="latent/video/ltxv", category="latent/video/ltxv",
inputs=[ inputs=[
io.Int.Input(id="width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32), io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input(id="height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32), io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input(id="length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=8), io.Int.Input("length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input(id="batch_size", default=1, min=1, max=4096), io.Int.Input("batch_size", default=1, min=1, max=4096),
], ],
outputs=[ outputs=[
io.Latent.Output(), io.Latent.Output(),
@@ -122,10 +122,10 @@ class LTXVAddGuide(io.ComfyNode):
node_id="LTXVAddGuide_V3", node_id="LTXVAddGuide_V3",
category="conditioning/video_models", category="conditioning/video_models",
inputs=[ inputs=[
io.Conditioning.Input(id="positive"), io.Conditioning.Input("positive"),
io.Conditioning.Input(id="negative"), io.Conditioning.Input("negative"),
io.Vae.Input(id="vae"), io.Vae.Input("vae"),
io.Latent.Input(id="latent"), io.Latent.Input("latent"),
io.Image.Input( io.Image.Input(
id="image", id="image",
tooltip="Image or video to condition the latent video on. Must be 8*n + 1 frames. " tooltip="Image or video to condition the latent video on. Must be 8*n + 1 frames. "
@@ -141,12 +141,12 @@ class LTXVAddGuide(io.ComfyNode):
"For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded " "For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
"down to the nearest multiple of 8. Negative values are counted from the end of the video.", "down to the nearest multiple of 8. Negative values are counted from the end of the video.",
), ),
io.Float.Input(id="strength", default=1.0, min=0.0, max=1.0, step=0.01), io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
], ],
outputs=[ outputs=[
io.Conditioning.Output(id="positive_out", display_name="positive"), io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(id="negative_out", display_name="negative"), io.Conditioning.Output(display_name="negative"),
io.Latent.Output(id="latent_out", display_name="latent"), io.Latent.Output(display_name="latent"),
], ],
) )
@@ -282,13 +282,13 @@ class LTXVConditioning(io.ComfyNode):
node_id="LTXVConditioning_V3", node_id="LTXVConditioning_V3",
category="conditioning/video_models", category="conditioning/video_models",
inputs=[ inputs=[
io.Conditioning.Input(id="positive"), io.Conditioning.Input("positive"),
io.Conditioning.Input(id="negative"), io.Conditioning.Input("negative"),
io.Float.Input(id="frame_rate", default=25.0, min=0.0, max=1000.0, step=0.01), io.Float.Input("frame_rate", default=25.0, min=0.0, max=1000.0, step=0.01),
], ],
outputs=[ outputs=[
io.Conditioning.Output(id="positive_out", display_name="positive"), io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(id="negative_out", display_name="negative"), io.Conditioning.Output(display_name="negative"),
], ],
) )
@@ -306,14 +306,14 @@ class LTXVCropGuides(io.ComfyNode):
node_id="LTXVCropGuides_V3", node_id="LTXVCropGuides_V3",
category="conditioning/video_models", category="conditioning/video_models",
inputs=[ inputs=[
io.Conditioning.Input(id="positive"), io.Conditioning.Input("positive"),
io.Conditioning.Input(id="negative"), io.Conditioning.Input("negative"),
io.Latent.Input(id="latent"), io.Latent.Input("latent"),
], ],
outputs=[ outputs=[
io.Conditioning.Output(id="positive_out", display_name="positive"), io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(id="negative_out", display_name="negative"), io.Conditioning.Output(display_name="negative"),
io.Latent.Output(id="latent_out", display_name="latent"), io.Latent.Output(display_name="latent"),
], ],
) )
@@ -342,19 +342,19 @@ class LTXVImgToVideo(io.ComfyNode):
node_id="LTXVImgToVideo_V3", node_id="LTXVImgToVideo_V3",
category="conditioning/video_models", category="conditioning/video_models",
inputs=[ inputs=[
io.Conditioning.Input(id="positive"), io.Conditioning.Input("positive"),
io.Conditioning.Input(id="negative"), io.Conditioning.Input("negative"),
io.Vae.Input(id="vae"), io.Vae.Input("vae"),
io.Image.Input(id="image"), io.Image.Input("image"),
io.Int.Input(id="width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32), io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input(id="height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32), io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input(id="length", default=97, min=9, max=nodes.MAX_RESOLUTION, step=8), io.Int.Input("length", default=97, min=9, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input(id="batch_size", default=1, min=1, max=4096), io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Float.Input(id="strength", default=1.0, min=0.0, max=1.0), io.Float.Input("strength", default=1.0, min=0.0, max=1.0),
], ],
outputs=[ outputs=[
io.Conditioning.Output(id="positive_out", display_name="positive"), io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(id="negative_out", display_name="negative"), io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"), io.Latent.Output(display_name="latent"),
], ],
) )
@@ -390,13 +390,13 @@ class LTXVPreprocess(io.ComfyNode):
node_id="LTXVPreprocess_V3", node_id="LTXVPreprocess_V3",
category="image", category="image",
inputs=[ inputs=[
io.Image.Input(id="image"), io.Image.Input("image"),
io.Int.Input( io.Int.Input(
id="img_compression", default=35, min=0, max=100, tooltip="Amount of compression to apply on image." id="img_compression", default=35, min=0, max=100, tooltip="Amount of compression to apply on image."
), ),
], ],
outputs=[ outputs=[
io.Image.Output(id="output_image", display_name="output_image"), io.Image.Output(display_name="output_image"),
], ],
) )
@@ -415,9 +415,9 @@ class LTXVScheduler(io.ComfyNode):
node_id="LTXVScheduler_V3", node_id="LTXVScheduler_V3",
category="sampling/custom_sampling/schedulers", category="sampling/custom_sampling/schedulers",
inputs=[ inputs=[
io.Int.Input(id="steps", default=20, min=1, max=10000), io.Int.Input("steps", default=20, min=1, max=10000),
io.Float.Input(id="max_shift", default=2.05, min=0.0, max=100.0, step=0.01), io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
io.Float.Input(id="base_shift", default=0.95, min=0.0, max=100.0, step=0.01), io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
io.Boolean.Input( io.Boolean.Input(
id="stretch", id="stretch",
default=True, default=True,
@@ -431,7 +431,7 @@ class LTXVScheduler(io.ComfyNode):
step=0.01, step=0.01,
tooltip="The terminal value of the sigmas after stretching.", tooltip="The terminal value of the sigmas after stretching.",
), ),
io.Latent.Input(id="latent", optional=True), io.Latent.Input("latent", optional=True),
], ],
outputs=[ outputs=[
io.Sigmas.Output(), io.Sigmas.Output(),
@@ -478,10 +478,10 @@ class ModelSamplingLTXV(io.ComfyNode):
node_id="ModelSamplingLTXV_V3", node_id="ModelSamplingLTXV_V3",
category="advanced/model", category="advanced/model",
inputs=[ inputs=[
io.Model.Input(id="model"), io.Model.Input("model"),
io.Float.Input(id="max_shift", default=2.05, min=0.0, max=100.0, step=0.01), io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
io.Float.Input(id="base_shift", default=0.95, min=0.0, max=100.0, step=0.01), io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
io.Latent.Input(id="latent", optional=True), io.Latent.Input("latent", optional=True),
], ],
outputs=[ outputs=[
io.Model.Output(), io.Model.Output(),

View File

@@ -0,0 +1,116 @@
from __future__ import annotations
import torch
from comfy_api.v3 import io
class CLIPTextEncodeLumina2(io.ComfyNode):
SYSTEM_PROMPT = {
"superior": "You are an assistant designed to generate superior images with the superior "
"degree of image-text alignment based on textual prompts or user prompts.",
"alignment": "You are an assistant designed to generate high-quality images with the "
"highest degree of image-text alignment based on textual prompts."
}
    SYSTEM_PROMPT_TIP = "Lumina2 provides two types of system prompts: " \
"Superior: You are an assistant designed to generate superior images with the superior "\
"degree of image-text alignment based on textual prompts or user prompts. "\
"Alignment: You are an assistant designed to generate high-quality images with the highest "\
"degree of image-text alignment based on textual prompts."
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeLumina2_V3",
display_name="CLIP Text Encode for Lumina2 _V3",
category="conditioning",
description="Encodes a system prompt and a user prompt using a CLIP model into an embedding "
"that can be used to guide the diffusion model towards generating specific images.",
inputs=[
io.Combo.Input("system_prompt", options=list(cls.SYSTEM_PROMPT.keys()), tooltip=cls.SYSTEM_PROMPT_TIP),
io.String.Input("user_prompt", multiline=True, dynamic_prompts=True, tooltip="The text to be encoded."),
io.Clip.Input("clip", tooltip="The CLIP model used for encoding the text."),
],
outputs=[
io.Conditioning.Output(tooltip="A conditioning containing the embedded text used to guide the diffusion model."),
],
)
@classmethod
def execute(cls, system_prompt, user_prompt, clip):
if clip is None:
raise RuntimeError(
"ERROR: clip input is invalid: None\n\n"
"If the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model."
)
system_prompt = cls.SYSTEM_PROMPT[system_prompt]
prompt = f'{system_prompt} <Prompt Start> {user_prompt}'
tokens = clip.tokenize(prompt)
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
class RenormCFG(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="RenormCFG_V3",
category="advanced/model",
inputs=[
io.Model.Input("model"),
io.Float.Input("cfg_trunc", default=100, min=0.0, max=100.0, step=0.01),
io.Float.Input("renorm_cfg", default=1.0, min=0.0, max=100.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, cfg_trunc, renorm_cfg):
def renorm_cfg_func(args):
cond_denoised = args["cond_denoised"]
uncond_denoised = args["uncond_denoised"]
cond_scale = args["cond_scale"]
timestep = args["timestep"]
x_orig = args["input"]
in_channels = model.model.diffusion_model.in_channels
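# For timesteps below cfg_trunc the usual CFG mix is applied to the first
# in_channels channels (cond_eps / uncond_eps below), and when renorm_cfg > 0
# the mixed result is rescaled so its norm never exceeds renorm_cfg times the
# norm of the conditional prediction. At or above cfg_trunc the conditional
# prediction is returned without guidance.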
if timestep[0] < cfg_trunc:
cond_eps, uncond_eps = cond_denoised[:, :in_channels], uncond_denoised[:, :in_channels]
cond_rest, _ = cond_denoised[:, in_channels:], uncond_denoised[:, in_channels:]
half_eps = uncond_eps + cond_scale * (cond_eps - uncond_eps)
half_rest = cond_rest
if float(renorm_cfg) > 0.0:
ori_pos_norm = torch.linalg.vector_norm(
cond_eps,
dim=tuple(range(1, len(cond_eps.shape))),
keepdim=True
)
max_new_norm = ori_pos_norm * float(renorm_cfg)
new_pos_norm = torch.linalg.vector_norm(
half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True
)
if new_pos_norm >= max_new_norm:
half_eps = half_eps * (max_new_norm / new_pos_norm)
else:
cond_eps, uncond_eps = cond_denoised[:, :in_channels], uncond_denoised[:, :in_channels]
cond_rest, _ = cond_denoised[:, in_channels:], uncond_denoised[:, in_channels:]
half_eps = cond_eps
half_rest = cond_rest
cfg_result = torch.cat([half_eps, half_rest], dim=1)
# cfg_result = uncond_denoised + (cond_denoised - uncond_denoised) * cond_scale
return x_orig - cfg_result
m = model.clone()
m.set_model_sampler_cfg_function(renorm_cfg_func)
return io.NodeOutput(m)
NODES_LIST = [
CLIPTextEncodeLumina2,
RenormCFG,
]

View File

@@ -23,12 +23,12 @@ class ImageRGBToYUV(io.ComfyNode):
node_id="ImageRGBToYUV_V3", node_id="ImageRGBToYUV_V3",
category="image/batch", category="image/batch",
inputs=[ inputs=[
io.Image.Input(id="image"), io.Image.Input("image"),
], ],
outputs=[ outputs=[
io.Image.Output(id="Y", display_name="Y"), io.Image.Output(display_name="Y"),
io.Image.Output(id="U", display_name="U"), io.Image.Output(display_name="U"),
io.Image.Output(id="V", display_name="V"), io.Image.Output(display_name="V"),
], ],
) )
@@ -45,9 +45,9 @@ class ImageYUVToRGB(io.ComfyNode):
node_id="ImageYUVToRGB_V3", node_id="ImageYUVToRGB_V3",
category="image/batch", category="image/batch",
inputs=[ inputs=[
io.Image.Input(id="Y"), io.Image.Input("Y"),
io.Image.Input(id="U"), io.Image.Input("U"),
io.Image.Input(id="V"), io.Image.Input("V"),
], ],
outputs=[ outputs=[
io.Image.Output(), io.Image.Output(),
@@ -68,9 +68,9 @@ class Morphology(io.ComfyNode):
display_name="ImageMorphology _V3", display_name="ImageMorphology _V3",
category="image/postprocessing", category="image/postprocessing",
inputs=[ inputs=[
io.Image.Input(id="image"), io.Image.Input("image"),
io.Combo.Input(id="operation", options=["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"]), io.Combo.Input("operation", options=["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"]),
io.Int.Input(id="kernel_size", default=3, min=3, max=999, step=1), io.Int.Input("kernel_size", default=3, min=3, max=999, step=1),
], ],
outputs=[ outputs=[
io.Image.Output(), io.Image.Output(),

View File

@@ -33,9 +33,9 @@ class OptimalStepsScheduler(io.ComfyNode):
node_id="OptimalStepsScheduler_V3", node_id="OptimalStepsScheduler_V3",
category="sampling/custom_sampling/schedulers", category="sampling/custom_sampling/schedulers",
inputs=[ inputs=[
io.Combo.Input(id="model_type", options=["FLUX", "Wan", "Chroma"]), io.Combo.Input("model_type", options=["FLUX", "Wan", "Chroma"]),
io.Int.Input(id="steps", default=20, min=3, max=1000), io.Int.Input("steps", default=20, min=3, max=1000),
io.Float.Input(id="denoise", default=1.0, min=0.0, max=1.0, step=0.01), io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
], ],
outputs=[ outputs=[
io.Sigmas.Output(), io.Sigmas.Output(),

View File

@@ -17,8 +17,8 @@ class PerturbedAttentionGuidance(io.ComfyNode):
node_id="PerturbedAttentionGuidance_V3", node_id="PerturbedAttentionGuidance_V3",
category="model_patches/unet", category="model_patches/unet",
inputs=[ inputs=[
io.Model.Input(id="model"), io.Model.Input("model"),
io.Float.Input(id="scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01), io.Float.Input("scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01),
], ],
outputs=[ outputs=[
io.Model.Output(), io.Model.Output(),

View File

@@ -88,12 +88,12 @@ class PerpNegGuider(io.ComfyNode):
node_id="PerpNegGuider_V3", node_id="PerpNegGuider_V3",
category="_for_testing", category="_for_testing",
inputs=[ inputs=[
io.Model.Input(id="model"), io.Model.Input("model"),
io.Conditioning.Input(id="positive"), io.Conditioning.Input("positive"),
io.Conditioning.Input(id="negative"), io.Conditioning.Input("negative"),
io.Conditioning.Input(id="empty_conditioning"), io.Conditioning.Input("empty_conditioning"),
io.Float.Input(id="cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
io.Float.Input(id="neg_scale", default=1.0, min=0.0, max=100.0, step=0.01), io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01),
], ],
outputs=[ outputs=[
io.Guider.Output(), io.Guider.Output(),

View File

@@ -0,0 +1,70 @@
"""TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)"""
from __future__ import annotations
import torch
from comfy_api.v3 import io
def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
"""Drop tangential components from uncond score to align with cond score."""
# (B, 1, ...)
batch_num = cond_score.shape[0]
cond_score_flat = cond_score.reshape(batch_num, 1, -1).float()
uncond_score_flat = uncond_score.reshape(batch_num, 1, -1).float()
# Score matrix A (B, 2, ...)
score_matrix = torch.cat((uncond_score_flat, cond_score_flat), dim=1)
try:
_, _, Vh = torch.linalg.svd(score_matrix, full_matrices=False)
except RuntimeError:
# Fallback to CPU
_, _, Vh = torch.linalg.svd(score_matrix.cpu(), full_matrices=False)
# Drop the tangential components
v1 = Vh[:, 0:1, :].to(uncond_score_flat.device) # (B, 1, ...)
uncond_score_td = (uncond_score_flat @ v1.transpose(-2, -1)) * v1
return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)
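# Intuition: the flattened cond/uncond scores are stacked into a (B, 2, D)
# matrix and decomposed per batch element; Vh[:, 0] is the dominant shared
# direction. Projecting the uncond score onto that direction keeps the part
# aligned with the cond score and drops the orthogonal ("tangential")
# remainder. A minimal sketch with tiny illustrative tensors:
#
#   u = torch.tensor([[1.0, 0.0]])
#   c = torch.tensor([[1.0, 1.0]])
#   A = torch.stack([u, c], dim=1)                    # (1, 2, 2)
#   _, _, Vh = torch.linalg.svd(A, full_matrices=False)
#   v1 = Vh[:, 0:1, :]                                # dominant direction
#   u_td = (u.unsqueeze(1) @ v1.transpose(-2, -1)) * v1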
class TCFG(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TCFG_V3",
display_name="Tangential Damping CFG _V3",
category="advanced/guidance",
description="TCFG Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality.",
inputs=[
io.Model.Input("model"),
],
outputs=[
io.Model.Output(display_name="patched_model"),
],
)
@classmethod
def execute(cls, model):
m = model.clone()
def tangential_damping_cfg(args):
# Assume [cond, uncond, ...]
x = args["input"]
conds_out = args["conds_out"]
if len(conds_out) <= 1 or None in args["conds"][:2]:
# Skip when either cond or uncond is None
return conds_out
cond_pred = conds_out[0]
uncond_pred = conds_out[1]
uncond_td = score_tangential_damping(x - cond_pred, x - uncond_pred)
uncond_pred_td = x - uncond_td
return [cond_pred, uncond_pred_td] + conds_out[2:]
m.set_model_sampler_pre_cfg_function(tangential_damping_cfg)
return io.NodeOutput(m)
NODES_LIST = [
TCFG,
]

View File

@@ -0,0 +1,190 @@
"""Taken from: https://github.com/dbolya/tomesd"""
from __future__ import annotations
import math
from typing import Callable, Tuple
import torch
from comfy_api.v3 import io
def do_nothing(x: torch.Tensor, mode:str=None):
return x
def mps_gather_workaround(input, dim, index):
if input.shape[-1] == 1:
return torch.gather(
input.unsqueeze(-1),
dim - 1 if dim < 0 else dim,
index.unsqueeze(-1)
).squeeze(-1)
return torch.gather(input, dim, index)
def bipartite_soft_matching_random2d(
metric: torch.Tensor, w: int, h: int, sx: int, sy: int, r: int, no_rand: bool = False
) -> Tuple[Callable, Callable]:
"""
Partitions the tokens into src and dst and merges r tokens from src to dst.
Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.
Args:
- metric [B, N, C]: metric to use for similarity
- w: image width in tokens
- h: image height in tokens
- sx: stride in the x dimension for dst, must divide w
- sy: stride in the y dimension for dst, must divide h
- r: number of tokens to remove (by merging)
- no_rand: if true, disable randomness (use top left corner only)
"""
B, N, _ = metric.shape
if r <= 0 or w == 1 or h == 1:
return do_nothing, do_nothing
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
with torch.no_grad():
hsy, wsx = h // sy, w // sx
# For each sy by sx kernel, randomly assign one token to be dst and the rest src
if no_rand:
rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
else:
rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)
# The image might not divide sx and sy, so we need to work on a view of the top left of the idx buffer instead
idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)
# Image is not divisible by sx or sy so we need to move it into a new buffer
if (hsy * sy) < h or (wsx * sx) < w:
idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
else:
idx_buffer = idx_buffer_view
# We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)
# We're finished with these
del idx_buffer, idx_buffer_view
# rand_idx is currently dst|src, so split them
num_dst = hsy * wsx
a_idx = rand_idx[:, num_dst:, :] # src
b_idx = rand_idx[:, :num_dst, :] # dst
def split(x):
C = x.shape[-1]
src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
return src, dst
# Cosine similarity between A and B
metric = metric / metric.norm(dim=-1, keepdim=True)
a, b = split(metric)
scores = a @ b.transpose(-1, -2)
# Can't reduce more than the # tokens in src
r = min(a.shape[1], r)
# Find the most similar greedily
node_max, node_idx = scores.max(dim=-1)
edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
unm_idx = edge_idx[..., r:, :] # Unmerged Tokens
src_idx = edge_idx[..., :r, :] # Merged Tokens
dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)
def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
src, dst = split(x)
n, t1, c = src.shape
unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
return torch.cat([unm, dst], dim=1)
def unmerge(x: torch.Tensor) -> torch.Tensor:
unm_len = unm_idx.shape[1]
unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
_, _, c = unm.shape
src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))
# Combine back to the original shape
out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)
return out
return merge, unmerge
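# merge() reduces the token count from N to N - r by averaging each of the r
# most similar "src" tokens into its best-matching "dst" token; unmerge()
# scatters the merged values back to the original N positions. In this patch
# the pair wraps self-attention, roughly (attention() is a placeholder for the
# attn1 block that receives the patched q):
#
#   m, u = bipartite_soft_matching_random2d(q, w, h, 2, 2, r)
#   out = attention(m(q), k, v)   # attention over fewer query tokens
#   out = u(out)                  # restore the original token layout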
def get_functions(x, ratio, original_shape):
b, c, original_h, original_w = original_shape
original_tokens = original_h * original_w
downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
stride_x = 2
stride_y = 2
max_downsample = 1
if downsample <= max_downsample:
w = int(math.ceil(original_w / downsample))
h = int(math.ceil(original_h / downsample))
r = int(x.shape[1] * ratio)
no_rand = False
m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
return m, u
def nothing(y):
return y
return nothing, nothing
class TomePatchModel(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TomePatchModel_V3",
category="model_patches/unet",
inputs=[
io.Model.Input("model"),
io.Float.Input("ratio", default=0.3, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, ratio):
u = None
def tomesd_m(q, k, v, extra_options):
nonlocal u
#NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q
#however from my basic testing it seems that using q instead gives better results
m, u = get_functions(q, ratio, extra_options["original_shape"])
return m(q), k, v
def tomesd_u(n, extra_options):
return u(n)
m = model.clone()
m.set_model_attn1_patch(tomesd_m)
m.set_model_attn1_output_patch(tomesd_u)
return io.NodeOutput(m)
NODES_LIST = [
TomePatchModel,
]

View File

@@ -0,0 +1,32 @@
from __future__ import annotations
from comfy_api.torch_helpers import set_torch_compile_wrapper
from comfy_api.v3 import io
class TorchCompileModel(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TorchCompileModel_V3",
category="_for_testing",
is_experimental=True,
inputs=[
io.Model.Input("model"),
io.Combo.Input("backend", options=["inductor", "cudagraphs"]),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, backend):
m = model.clone()
set_torch_compile_wrapper(model=m, backend=backend)
return io.NodeOutput(m)
NODES_LIST = [
TorchCompileModel,
]

View File

@@ -0,0 +1,666 @@
from __future__ import annotations
import logging
import os
import numpy as np
import safetensors
import torch
import torch.utils.checkpoint
import tqdm
from PIL import Image, ImageDraw, ImageFont
import comfy.model_management
import comfy.samplers
import comfy.sd
import comfy.utils
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.weight_adapter import adapter_maps, adapters
from comfy_api.v3 import io, ui
def make_batch_extra_option_dict(d, indicies, full_size=None):
new_dict = {}
for k, v in d.items():
newv = v
if isinstance(v, dict):
newv = make_batch_extra_option_dict(v, indicies, full_size=full_size)
elif isinstance(v, torch.Tensor):
if full_size is None or v.size(0) == full_size:
newv = v[indicies]
elif isinstance(v, (list, tuple)) and len(v) == full_size:
newv = [v[i] for i in indicies]
new_dict[k] = newv
return new_dict
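# Recursively slices every per-sample tensor or list in the sampler's
# extra-args dict down to the indices of the current mini-batch, leaving
# anything that is not batched (wrong first dimension) untouched.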
class TrainSampler(comfy.samplers.Sampler):
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
self.loss_fn = loss_fn
self.optimizer = optimizer
self.loss_callback = loss_callback
self.batch_size = batch_size
self.total_steps = total_steps
self.grad_acc = grad_acc
self.seed = seed
self.training_dtype = training_dtype
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
cond = model_wrap.conds["positive"]
dataset_size = sigmas.size(0)
torch.cuda.empty_cache()
for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()
batch_latent = torch.stack([latent_image[i] for i in indicies])
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
batch_sigmas = [
model_wrap.inner_model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
) for _ in range(min(self.batch_size, dataset_size))
]
batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
xt = model_wrap.inner_model.model_sampling.noise_scaling(
batch_sigmas,
batch_noise,
batch_latent,
False
)
x0 = model_wrap.inner_model.model_sampling.noise_scaling(
torch.zeros_like(batch_sigmas),
torch.zeros_like(batch_noise),
batch_latent,
False
)
model_wrap.conds["positive"] = [
cond[i] for i in indicies
]
batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)
with torch.autocast(xt.device.type, dtype=self.training_dtype):
x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
loss = self.loss_fn(x0_pred, x0)
loss.backward()
if self.loss_callback:
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
if (i + 1) % self.grad_acc == 0:
self.optimizer.step()
self.optimizer.zero_grad()
torch.cuda.empty_cache()
return torch.zeros_like(latent_image)
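# Each step above: draw a random mini-batch of latents, noise them at randomly
# sampled sigmas, run the wrapped model to predict the clean latents, compare
# the prediction against the unnoised targets with loss_fn, and step the
# optimizer once every grad_acc iterations (gradient accumulation). The
# returned zero tensor only satisfies the sampler interface; the real output
# is the set of trained adapter weights.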
class BiasDiff(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.bias = bias
def __call__(self, b):
org_dtype = b.dtype
return (b.to(self.bias) + self.bias).to(org_dtype)
def passive_memory_usage(self):
return self.bias.nelement() * self.bias.element_size()
def move_to(self, device):
self.to(device=device)
return self.passive_memory_usage()
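# BiasDiff is a minimal weight wrapper: the frozen base parameter stays
# untouched and a trainable additive delta (self.bias) is applied on the fly,
# so only the delta needs gradients and only the delta ends up in the saved
# state dict (the ".diff" / ".diff_b" keys used in TrainLoraNode below).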
def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None):
"""Utility function to load and process a list of images.
Args:
image_files: List of image filenames
input_dir: Base directory containing the images
resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")
w, h: Optional target width/height; when omitted, the size of the first image is used
Returns:
torch.Tensor: Batch of processed images
"""
if not image_files:
raise ValueError("No valid images found in input")
output_images = []
for file in image_files:
image_path = os.path.join(input_dir, file)
img = node_helpers.pillow(Image.open, image_path)
if img.mode == "I":
img = img.point(lambda i: i * (1 / 255))
img = img.convert("RGB")
if w is None and h is None:
w, h = img.size[0], img.size[1]
# Resize image to first image
if img.size[0] != w or img.size[1] != h:
if resize_method == "Stretch":
img = img.resize((w, h), Image.Resampling.LANCZOS)
elif resize_method == "Crop":
img = img.crop((0, 0, w, h))
elif resize_method == "Pad":
img = img.resize((w, h), Image.Resampling.LANCZOS)
elif resize_method == "None":
raise ValueError(
"Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
)
img_array = np.array(img).astype(np.float32) / 255.0
img_tensor = torch.from_numpy(img_array)[None,]
output_images.append(img_tensor)
return torch.cat(output_images, dim=0)
def draw_loss_graph(loss_map, steps):
width, height = 500, 300
img = Image.new("RGB", (width, height), "white")
draw = ImageDraw.Draw(img)
min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
scaled_loss = [(l_v - min_loss) / (max_loss - min_loss) for l_v in loss_map.values()]
prev_point = (0, height - int(scaled_loss[0] * height))
for i, l_v in enumerate(scaled_loss[1:], start=1):
x = int(i / (steps - 1) * width)
y = height - int(l_v * height)
draw.line([prev_point, (x, y)], fill="blue", width=2)
prev_point = (x, y)
return img
def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None):
if result is None:
result = []
elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
result.append(model)
logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
return result
name = name or "root"
for next_name, child in model.named_children():
find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}")
return result
def patch(m):
if not hasattr(m, "forward"):
return
org_forward = m.forward
def fwd(args, kwargs):
return org_forward(*args, **kwargs)
def checkpointing_fwd(*args, **kwargs):
return torch.utils.checkpoint.checkpoint(
fwd, args, kwargs, use_reentrant=False
)
m.org_forward = org_forward
m.forward = checkpointing_fwd
def unpatch(m):
if hasattr(m, "org_forward"):
m.forward = m.org_forward
del m.org_forward
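# patch()/unpatch() implement gradient checkpointing: every highest-level child
# module that defines forward() has its forward wrapped in
# torch.utils.checkpoint.checkpoint, so activations are recomputed during the
# backward pass instead of being kept in memory, trading extra compute for a
# much smaller VRAM footprint while training.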
class LoadImageSetFromFolderNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoadImageSetFromFolderNode_V3",
display_name="Load Image Dataset from Folder _V3",
category="loaders",
description="Loads a batch of images from a directory for training.",
is_experimental=True,
inputs=[
io.Combo.Input(
"folder", options=folder_paths.get_input_subfolders(), tooltip="The folder to load images from."
),
io.Combo.Input(
"resize_method", options=["None", "Stretch", "Crop", "Pad"], default="None", optional=True
),
],
outputs=[
io.Image.Output(),
],
)
@classmethod
def execute(cls, folder, resize_method="None"):
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
image_files = [
f
for f in os.listdir(sub_input_dir)
if any(f.lower().endswith(ext) for ext in valid_extensions)
]
return io.NodeOutput(load_and_process_images(image_files, sub_input_dir, resize_method))
class LoadImageTextSetFromFolderNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoadImageTextSetFromFolderNode_V3",
display_name="Load Image and Text Dataset from Folder _V3",
category="loaders",
description="Loads a batch of images and caption from a directory for training.",
is_experimental=True,
inputs=[
io.Combo.Input("folder", options=folder_paths.get_input_subfolders(), tooltip="The folder to load images from."),
io.Clip.Input("clip", tooltip="The CLIP model used for encoding the text."),
io.Combo.Input("resize_method", options=["None", "Stretch", "Crop", "Pad"], default="None", optional=True),
io.Int.Input("width", default=-1, min=-1, max=10000, step=1, tooltip="The width to resize the images to. -1 means use the original width.", optional=True),
io.Int.Input("height", default=-1, min=-1, max=10000, step=1, tooltip="The height to resize the images to. -1 means use the original height.", optional=True),
],
outputs=[
io.Image.Output(),
io.Conditioning.Output(),
],
)
@classmethod
def execute(cls, folder, clip, resize_method="None", width=None, height=None):
if clip is None:
raise RuntimeError(
"ERROR: clip input is invalid: None\n\n"
"If the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model."
)
logging.info(f"Loading images from folder: {folder}")
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
image_files = []
for item in os.listdir(sub_input_dir):
path = os.path.join(sub_input_dir, item)
if any(item.lower().endswith(ext) for ext in valid_extensions):
image_files.append(path)
elif os.path.isdir(path):
# Support kohya-ss/sd-scripts folder structure
repeat = 1
if item.split("_")[0].isdigit():
repeat = int(item.split("_")[0])
image_files.extend([
os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions)
] * repeat)
caption_file_path = [
f.replace(os.path.splitext(f)[1], ".txt")
for f in image_files
]
captions = []
for caption_file in caption_file_path:
caption_path = os.path.join(sub_input_dir, caption_file)
if os.path.exists(caption_path):
with open(caption_path, "r", encoding="utf-8") as f:
caption = f.read().strip()
captions.append(caption)
else:
captions.append("")
width = width if width != -1 else None
height = height if height != -1 else None
output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height)
logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")
logging.info(f"Encoding captions from {sub_input_dir}.")
conditions = []
empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
for text in captions:
if text == "":
conditions.append(empty_cond)
continue
tokens = clip.tokenize(text)
conditions.extend(clip.encode_from_tokens_scheduled(tokens))
logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.")
return io.NodeOutput(output_tensor, conditions)
class LoraModelLoader(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoraModelLoader_V3",
display_name="Load LoRA Model _V3",
category="loaders",
description="Load Trained LoRA weights from Train LoRA node.",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The diffusion model the LoRA will be applied to."),
io.LoraModel.Input("lora", tooltip="The LoRA model to apply to the diffusion model."),
io.Float.Input("strength_model", default=1.0, min=-100.0, max=100.0, step=0.01, tooltip="How strongly to modify the diffusion model. This value can be negative."),
],
outputs=[
io.Model.Output(tooltip="The modified diffusion model."),
],
)
@classmethod
def execute(cls, model, lora, strength_model):
if strength_model == 0:
return io.NodeOutput(model)
model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
return io.NodeOutput(model_lora)
class LossGraphNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LossGraphNode_V3",
display_name="Plot Loss Graph _V3",
category="training",
description="Plots the loss graph and saves it to the output directory.",
is_experimental=True,
is_output_node=True,
inputs=[
io.LossMap.Input("loss"), # TODO: original V1 node has also `default={}` parameter
io.String.Input("filename_prefix", default="loss_graph"),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
)
@classmethod
def execute(cls, loss, filename_prefix):
loss_values = loss["loss"]
width, height = 800, 480
margin = 40
img = Image.new(
"RGB", (width + margin, height + margin), "white"
) # Extend canvas
draw = ImageDraw.Draw(img)
min_loss, max_loss = min(loss_values), max(loss_values)
scaled_loss = [(l_v - min_loss) / (max_loss - min_loss) for l_v in loss_values]
steps = len(loss_values)
prev_point = (margin, height - int(scaled_loss[0] * height))
for i, l_v in enumerate(scaled_loss[1:], start=1):
x = margin + int(i / steps * width) # Scale X properly
y = height - int(l_v * height)
draw.line([prev_point, (x, y)], fill="blue", width=2)
prev_point = (x, y)
draw.line([(margin, 0), (margin, height)], fill="black", width=2) # Y-axis
draw.line(
[(margin, height), (width + margin, height)], fill="black", width=2
) # X-axis
try:
font = ImageFont.truetype("arial.ttf", 12)
except IOError:
font = ImageFont.load_default()
# Add axis labels
draw.text((5, height // 2), "Loss", font=font, fill="black")
draw.text((width // 2, height + 10), "Steps", font=font, fill="black")
# Add min/max loss values
draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
draw.text(
(margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
)
return io.NodeOutput(ui=ui.PreviewImage(img, cls=cls))
class SaveLoRA(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SaveLoRA_V3",
display_name="Save LoRA Weights _V3",
category="loaders",
is_experimental=True,
is_output_node=True,
inputs=[
io.LoraModel.Input("lora", tooltip="The LoRA model to save. Do not use the model with LoRA layers."),
io.String.Input("prefix", default="loras/ComfyUI_trained_lora", tooltip="The prefix to use for the saved LoRA file."),
io.Int.Input("steps", tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.", optional=True),
],
outputs=[],
)
@classmethod
def execute(cls, lora, prefix, steps=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
prefix, folder_paths.get_output_directory()
)
if steps is None:
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
else:
output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
safetensors.torch.save_file(lora, output_checkpoint)
return io.NodeOutput()
class TrainLoraNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TrainLoraNode_V3",
display_name="Train LoRA _V3",
category="training",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to train the LoRA on."),
io.Latent.Input("latents", tooltip="The Latents to use for training, serve as dataset/input of the model."),
io.Conditioning.Input("positive", tooltip="The positive conditioning to use for training."),
io.Int.Input("batch_size", default=1, min=1, max=10000, step=1, tooltip="The batch size to use for training."),
io.Int.Input("grad_accumulation_steps", default=1, min=1, max=1024, step=1, tooltip="The number of gradient accumulation steps to use for training."),
io.Int.Input("steps", default=16, min=1, max=100000, tooltip="The number of steps to train the LoRA for."),
io.Float.Input("learning_rate", default=0.0005, min=0.0000001, max=1.0, step=0.000001, tooltip="The learning rate to use for training."),
io.Int.Input("rank", default=8, min=1, max=128, tooltip="The rank of the LoRA layers."),
io.Combo.Input("optimizer", options=["AdamW", "Adam", "SGD", "RMSprop"], default="AdamW", tooltip="The optimizer to use for training."),
io.Combo.Input("loss_function", options=["MSE", "L1", "Huber", "SmoothL1"], default="MSE", tooltip="The loss function to use for training."),
io.Int.Input("seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)"),
io.Combo.Input("training_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for training."),
io.Combo.Input("lora_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for lora."),
io.Combo.Input("algorithm", options=list(adapter_maps.keys()), default=list(adapter_maps.keys())[0], tooltip="The algorithm to use for training."),
io.Boolean.Input("gradient_checkpointing", default=True, tooltip="Use gradient checkpointing for training."),
io.Combo.Input("existing_lora", options=folder_paths.get_filename_list("loras") + ["[None]"], default="[None]", tooltip="The existing LoRA to append to. Set to None for new LoRA."),
],
outputs=[
io.Model.Output(display_name="model_with_lora"),
io.LoraModel.Output(display_name="lora"),
io.LossMap.Output(display_name="loss"),
io.Int.Output(display_name="steps"),
],
)
@classmethod
def execute(
cls,
model,
latents,
positive,
batch_size,
steps,
grad_accumulation_steps,
learning_rate,
rank,
optimizer,
loss_function,
seed,
training_dtype,
lora_dtype,
algorithm,
gradient_checkpointing,
existing_lora,
):
mp = model.clone()
dtype = node_helpers.string_to_torch_dtype(training_dtype)
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
latents = latents["samples"].to(dtype)
num_images = latents.shape[0]
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
positive = positive * num_images
elif len(positive) != num_images:
raise ValueError(
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
)
with torch.inference_mode(False):
lora_sd = {}
generator = torch.Generator()
generator.manual_seed(seed)
# Load existing LoRA weights if provided
existing_weights = {}
existing_steps = 0
if existing_lora != "[None]":
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
if lora_path:
existing_weights = comfy.utils.load_torch_file(lora_path)
all_weight_adapters = []
for n, m in mp.model.named_modules():
if hasattr(m, "weight_function"):
if m.weight is not None:
key = "{}.weight".format(n)
shape = m.weight.shape
if len(shape) >= 2:
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
dora_scale = existing_weights.get(
f"{key}.dora_scale", None
)
for adapter_cls in adapters:
existing_adapter = adapter_cls.load(
n, existing_weights, alpha, dora_scale
)
if existing_adapter is not None:
break
else:
existing_adapter = None
adapter_cls = adapter_maps[algorithm]
if existing_adapter is not None:
train_adapter = existing_adapter.to_train().to(lora_dtype)
else:
# Use LoRA with alpha=1.0 by default
train_adapter = adapter_cls.create_train(
m.weight, rank=rank, alpha=1.0
).to(lora_dtype)
for name, parameter in train_adapter.named_parameters():
lora_sd[f"{n}.{name}"] = parameter
mp.add_weight_wrapper(key, train_adapter)
all_weight_adapters.append(train_adapter)
else:
diff = torch.nn.Parameter(
torch.zeros(
m.weight.shape, dtype=lora_dtype, requires_grad=True
)
)
diff_module = BiasDiff(diff)
mp.add_weight_wrapper(key, BiasDiff(diff))
all_weight_adapters.append(diff_module)
lora_sd["{}.diff".format(n)] = diff
if hasattr(m, "bias") and m.bias is not None:
key = "{}.bias".format(n)
bias = torch.nn.Parameter(
torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True)
)
bias_module = BiasDiff(bias)
lora_sd["{}.diff_b".format(n)] = bias
mp.add_weight_wrapper(key, BiasDiff(bias))
all_weight_adapters.append(bias_module)
if optimizer == "Adam":
optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
elif optimizer == "AdamW":
optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
elif optimizer == "SGD":
optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
elif optimizer == "RMSprop":
optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
# Setup loss function based on selection
if loss_function == "MSE":
criterion = torch.nn.MSELoss()
elif loss_function == "L1":
criterion = torch.nn.L1Loss()
elif loss_function == "Huber":
criterion = torch.nn.HuberLoss()
elif loss_function == "SmoothL1":
criterion = torch.nn.SmoothL1Loss()
# setup models
if gradient_checkpointing:
for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
patch(m)
mp.model.requires_grad_(False)
comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
# Setup sampler and guider like in test script
loss_map = {"loss": []}
def loss_callback(loss):
loss_map["loss"].append(loss)
train_sampler = TrainSampler(
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
grad_acc=grad_accumulation_steps,
total_steps=steps * grad_accumulation_steps,
seed=seed,
training_dtype=dtype
)
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
guider.set_conds(positive) # Set conditioning from input
# Training loop
try:
# Generate dummy sigmas and noise
sigmas = torch.tensor(range(num_images))
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
seed=noise.seed
)
finally:
for m in mp.model.modules():
unpatch(m)
del train_sampler, optimizer
for adapter in all_weight_adapters:
adapter.requires_grad_(False)
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype)
return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
NODES_LIST = [
LoadImageSetFromFolderNode,
LoadImageTextSetFromFolderNode,
LoraModelLoader,
LossGraphNode,
SaveLoRA,
TrainLoraNode,
]

View File

@@ -0,0 +1,106 @@
from __future__ import annotations
import logging
import torch
from spandrel import ImageModelDescriptor, ModelLoader
import comfy.utils
import folder_paths
from comfy import model_management
from comfy_api.v3 import io
try:
from spandrel import MAIN_REGISTRY
from spandrel_extra_arches import EXTRA_REGISTRY
MAIN_REGISTRY.add(*EXTRA_REGISTRY)
logging.info("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
except Exception:
pass
class ImageUpscaleWithModel(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ImageUpscaleWithModel_V3",
display_name="Upscale Image (using Model) _V3",
category="image/upscaling",
inputs=[
io.UpscaleModel.Input("upscale_model"),
io.Image.Input("image"),
],
outputs=[
io.Image.Output(),
],
)
@classmethod
def execute(cls, upscale_model, image):
device = model_management.get_torch_device()
memory_required = model_management.module_size(upscale_model.model)
memory_required += (512 * 512 * 3) * image.element_size() * max(upscale_model.scale, 1.0) * 384.0 #The 384.0 is an estimate of how much some of these models take, TODO: make it more accurate
memory_required += image.nelement() * image.element_size()
model_management.free_memory(memory_required, device)
upscale_model.to(device)
in_img = image.movedim(-1,-3).to(device)
tile = 512
overlap = 32
oom = True
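# Tiled upscaling with an out-of-memory fallback: start with 512x512 tiles and
# halve the tile size whenever the GPU runs out of memory, giving up once the
# tile would drop below 128 pixels.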
while oom:
try:
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(
in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap
)
pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(
in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar
)
oom = False
except model_management.OOM_EXCEPTION as e:
tile //= 2
if tile < 128:
raise e
upscale_model.to("cpu")
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
return io.NodeOutput(s)
class UpscaleModelLoader(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="UpscaleModelLoader_V3",
display_name="Load Upscale Model _V3",
category="loaders",
inputs=[
io.Combo.Input("model_name", options=folder_paths.get_filename_list("upscale_models")),
],
outputs=[
io.UpscaleModel.Output(),
],
)
@classmethod
def execute(cls, model_name):
model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""})
out = ModelLoader().load_from_state_dict(sd).eval()
if not isinstance(out, ImageModelDescriptor):
raise Exception("Upscale model must be a single-image model.")
return io.NodeOutput(out)
NODES_LIST = [
ImageUpscaleWithModel,
UpscaleModelLoader,
]

View File

@@ -0,0 +1,232 @@
from __future__ import annotations
import torch
import comfy.sd
import comfy.utils
import comfy_extras.nodes_model_merging
import folder_paths
import node_helpers
import nodes
from comfy_api.v3 import io
class ConditioningSetAreaPercentageVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ConditioningSetAreaPercentageVideo_V3",
category="conditioning",
inputs=[
io.Conditioning.Input("conditioning"),
io.Float.Input("width", default=1.0, min=0, max=1.0, step=0.01),
io.Float.Input("height", default=1.0, min=0, max=1.0, step=0.01),
io.Float.Input("temporal", default=1.0, min=0, max=1.0, step=0.01),
io.Float.Input("x", default=0, min=0, max=1.0, step=0.01),
io.Float.Input("y", default=0, min=0, max=1.0, step=0.01),
io.Float.Input("z", default=0, min=0, max=1.0, step=0.01),
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
],
outputs=[
io.Conditioning.Output(),
],
)
@classmethod
def execute(cls, conditioning, width, height, temporal, x, y, z, strength):
c = node_helpers.conditioning_set_values(
conditioning,
{
"area": ("percentage", temporal, height, width, z, y, x),
"strength": strength,
"set_area_to_bounds": False
,}
)
return io.NodeOutput(c)
class ImageOnlyCheckpointLoader(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ImageOnlyCheckpointLoader_V3",
display_name="Image Only Checkpoint Loader (img2vid model) _V3",
category="loaders/video_models",
inputs=[
io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("checkpoints")),
],
outputs=[
io.Model.Output(),
io.ClipVision.Output(),
io.Vae.Output(),
],
)
@classmethod
def execute(cls, ckpt_name):
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(
ckpt_path,
output_vae=True,
output_clip=False,
output_clipvision=True,
embedding_directory=folder_paths.get_folder_paths("embeddings"),
)
return io.NodeOutput(out[0], out[3], out[2])
class ImageOnlyCheckpointSave(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ImageOnlyCheckpointSave_V3",
category="advanced/model_merging",
inputs=[
io.Model.Input("model"),
io.ClipVision.Input("clip_vision"),
io.Vae.Input("vae"),
io.String.Input("filename_prefix", default="checkpoints/ComfyUI"),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
)
@classmethod
def execute(cls, model, clip_vision, vae, filename_prefix):
output_dir = folder_paths.get_output_directory()
comfy_extras.nodes_model_merging.save_checkpoint(
model,
clip_vision=clip_vision,
vae=vae,
filename_prefix=filename_prefix,
output_dir=output_dir,
prompt=cls.hidden.prompt,
extra_pnginfo=cls.hidden.extra_pnginfo,
)
return io.NodeOutput()
class SVD_img2vid_Conditioning(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SVD_img2vid_Conditioning_V3",
category="conditioning/video_models",
inputs=[
io.ClipVision.Input("clip_vision"),
io.Image.Input("init_image"),
io.Vae.Input("vae"),
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=576, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("video_frames", default=14, min=1, max=4096),
io.Int.Input("motion_bucket_id", default=127, min=1, max=1023),
io.Int.Input("fps", default=6, min=1, max=1024),
io.Float.Input("augmentation_level", default=0.0, min=0.0, max=10.0, step=0.01),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(
init_image.movedim(-1,1), width, height, "bilinear", "center"
).movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
if augmentation_level > 0:
encode_pixels += torch.randn_like(pixels) * augmentation_level
t = vae.encode(encode_pixels)
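        # Encoded init frame is attached to the positive conditioning; the negative gets a zeroed latent of the same shape.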
positive = [
[
pooled,
{"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t},
]
]
negative = [
[
torch.zeros_like(pooled),
{"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)},
]
]
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
return io.NodeOutput(positive, negative, {"samples":latent})
class VideoLinearCFGGuidance(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VideoLinearCFGGuidance_V3",
category="sampling/video_models",
inputs=[
io.Model.Input("model"),
io.Float.Input("min_cfg", default=1.0, min=0.0, max=100.0, step=0.5, round=0.01),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, min_cfg):
def linear_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
scale = torch.linspace(
min_cfg, cond_scale, cond.shape[0], device=cond.device
).reshape((cond.shape[0], 1, 1, 1))
return uncond + scale * (cond - uncond)
m = model.clone()
m.set_model_sampler_cfg_function(linear_cfg)
return io.NodeOutput(m)
class VideoTriangleCFGGuidance(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VideoTriangleCFGGuidance_V3",
category="sampling/video_models",
inputs=[
io.Model.Input("model"),
io.Float.Input("min_cfg", default=1.0, min=0.0, max=100.0, step=0.5, round=0.01),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, min_cfg):
def linear_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
period = 1.0
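            # Triangle wave across the frames: CFG starts at min_cfg, peaks at cond_scale mid-sequence, then returns to min_cfg.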
values = torch.linspace(0, 1, cond.shape[0], device=cond.device)
values = 2 * (values / period - torch.floor(values / period + 0.5)).abs()
scale = (values * (cond_scale - min_cfg) + min_cfg).reshape((cond.shape[0], 1, 1, 1))
return uncond + scale * (cond - uncond)
m = model.clone()
m.set_model_sampler_cfg_function(linear_cfg)
return io.NodeOutput(m)
NODES_LIST = [
ConditioningSetAreaPercentageVideo,
ImageOnlyCheckpointLoader,
ImageOnlyCheckpointSave,
SVD_img2vid_Conditioning,
VideoLinearCFGGuidance,
VideoTriangleCFGGuidance,
]
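
The triangle-wave schedule in `VideoTriangleCFGGuidance` can be checked in isolation. A minimal sketch, assuming only PyTorch, with illustrative frame count and CFG values (`period` is fixed at 1.0 as in the node):

```python
import torch

# Illustrative values; the real node reads cond_scale from the sampler args at runtime.
num_frames, min_cfg, cond_scale = 14, 1.0, 2.5

v = torch.linspace(0, 1, num_frames)
tri = 2 * (v - torch.floor(v + 0.5)).abs()        # triangle wave: 0 -> 1 -> 0 across the frames
scale = tri * (cond_scale - min_cfg) + min_cfg    # min_cfg at both ends, cond_scale mid-sequence
print(scale.tolist())
```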

View File

@ -74,7 +74,8 @@ if not args.cuda_malloc:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
version = module.__version__
if int(version[0]) >= 2: #enable by default for torch version 2.0 and up
if int(version[0]) >= 2 and "+cu" in version: #enable by default for torch version 2.0 and up only on cuda torch
args.cuda_malloc = cuda_malloc_supported()
except:
pass

View File

@ -2320,9 +2320,17 @@ def init_builtin_extra_nodes():
"v3/nodes_fresca.py", "v3/nodes_fresca.py",
"v3/nodes_gits.py", "v3/nodes_gits.py",
"v3/nodes_hidream.py", "v3/nodes_hidream.py",
"v3/nodes_hunyuan.py",
"v3/nodes_hypernetwork.py",
"v3/nodes_hypertile.py",
"v3/nodes_images.py", "v3/nodes_images.py",
"v3/nodes_ip2p.py",
"v3/nodes_latent.py", "v3/nodes_latent.py",
"v3/nodes_load_3d.py",
"v3/nodes_lora_extract.py",
"v3/nodes_lotus.py",
"v3/nodes_lt.py", "v3/nodes_lt.py",
"v3/nodes_lumina2.py",
"v3/nodes_mask.py", "v3/nodes_mask.py",
"v3/nodes_mochi.py", "v3/nodes_mochi.py",
"v3/nodes_model_advanced.py", "v3/nodes_model_advanced.py",
@ -2342,7 +2350,13 @@ def init_builtin_extra_nodes():
"v3/nodes_sdupscale.py", "v3/nodes_sdupscale.py",
"v3/nodes_slg.py", "v3/nodes_slg.py",
"v3/nodes_stable_cascade.py", "v3/nodes_stable_cascade.py",
"v3/nodes_tcfg.py",
"v3/nodes_tomesd.py",
"v3/nodes_torch_compile.py",
"v3/nodes_train.py",
"v3/nodes_upscale_model.py",
"v3/nodes_video.py", "v3/nodes_video.py",
"v3/nodes_video_model.py",
"v3/nodes_wan.py", "v3/nodes_wan.py",
"v3/nodes_webcam.py", "v3/nodes_webcam.py",
] ]

View File

@ -554,6 +554,7 @@ class PromptServer():
ram_free = comfy.model_management.get_free_memory(cpu_device)
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
required_frontend_version = FrontendManager.get_required_frontend_version()
system_stats = {
"system": {
@ -561,6 +562,7 @@
"ram_total": ram_total,
"ram_free": ram_free,
"comfyui_version": __version__,
"required_frontend_version": required_frontend_version,
"python_version": sys.version,
"pytorch_version": comfy.model_management.torch_version,
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",

View File

@ -1,7 +1,7 @@
import argparse
import pytest
from requests.exceptions import HTTPError
from unittest.mock import patch
from unittest.mock import patch, mock_open
from app.frontend_management import (
FrontendManager,
@ -172,3 +172,36 @@ def test_init_frontend_fallback_on_error():
# Assert
assert frontend_path == "/default/path"
mock_check.assert_called_once()
def test_get_frontend_version():
# Arrange
expected_version = "1.25.0"
mock_requirements_content = """torch
torchsde
comfyui-frontend-package==1.25.0
other-package==1.0.0
numpy"""
# Act
with patch("builtins.open", mock_open(read_data=mock_requirements_content)):
version = FrontendManager.get_required_frontend_version()
# Assert
assert version == expected_version
def test_get_frontend_version_invalid_semver():
# Arrange
mock_requirements_content = """torch
torchsde
comfyui-frontend-package==1.29.3.75
other-package==1.0.0
numpy"""
# Act
with patch("builtins.open", mock_open(read_data=mock_requirements_content)):
version = FrontendManager.get_required_frontend_version()
# Assert
assert version is None