mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-30 09:46:37 +00:00

converted nodes files starting with "t" letter

This commit is contained in:
parent 487ec28b9c
commit 2ea2bc2941

70 comfy_extras/v3/nodes_tcfg.py Normal file
@@ -0,0 +1,70 @@
"""TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)"""

from __future__ import annotations

import torch

from comfy_api.v3 import io


def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
    """Drop tangential components from uncond score to align with cond score."""
    # (B, 1, ...)
    batch_num = cond_score.shape[0]
    cond_score_flat = cond_score.reshape(batch_num, 1, -1).float()
    uncond_score_flat = uncond_score.reshape(batch_num, 1, -1).float()

    # Score matrix A (B, 2, ...)
    score_matrix = torch.cat((uncond_score_flat, cond_score_flat), dim=1)
    try:
        _, _, Vh = torch.linalg.svd(score_matrix, full_matrices=False)
    except RuntimeError:
        # Fallback to CPU
        _, _, Vh = torch.linalg.svd(score_matrix.cpu(), full_matrices=False)

    # Drop the tangential components
    v1 = Vh[:, 0:1, :].to(uncond_score_flat.device)  # (B, 1, ...)
    uncond_score_td = (uncond_score_flat @ v1.transpose(-2, -1)) * v1
    return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)


class TCFG(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TCFG_V3",
            display_name="Tangential Damping CFG _V3",
            category="advanced/guidance",
            description="TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality.",
            inputs=[
                io.Model.Input("model"),
            ],
            outputs=[
                io.Model.Output(display_name="patched_model"),
            ],
        )

    @classmethod
    def execute(cls, model):
        m = model.clone()

        def tangential_damping_cfg(args):
            # Assume [cond, uncond, ...]
            x = args["input"]
            conds_out = args["conds_out"]
            if len(conds_out) <= 1 or None in args["conds"][:2]:
                # Skip when either cond or uncond is None
                return conds_out
            cond_pred = conds_out[0]
            uncond_pred = conds_out[1]
            uncond_td = score_tangential_damping(x - cond_pred, x - uncond_pred)
            uncond_pred_td = x - uncond_td
            return [cond_pred, uncond_pred_td] + conds_out[2:]

        m.set_model_sampler_pre_cfg_function(tangential_damping_cfg)
        return io.NodeOutput(m)


NODES_LIST = [
    TCFG,
]
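Reviewer note (not part of the diff): the SVD trick in score_tangential_damping keeps only the component of the uncond score that lies along the principal direction shared with the cond score. A minimal standalone sketch of the same projection, with invented tensor sizes:

import torch

cond = torch.randn(2, 16)    # stand-in for flattened cond scores, (B, C)
uncond = torch.randn(2, 16)  # stand-in for flattened uncond scores, (B, C)

A = torch.stack((uncond, cond), dim=1)           # score matrix, (B, 2, C)
_, _, Vh = torch.linalg.svd(A, full_matrices=False)
v1 = Vh[:, 0:1, :]                               # first right-singular vector, (B, 1, C)

# Projection of uncond onto v1; the tangential remainder is what TCFG drops.
uncond_td = (uncond.unsqueeze(1) @ v1.transpose(-2, -1)) * v1
print(uncond_td.squeeze(1).shape)                # (2, 16)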
190 comfy_extras/v3/nodes_tomesd.py Normal file
@@ -0,0 +1,190 @@
"""Taken from: https://github.com/dbolya/tomesd"""

from __future__ import annotations

import math
from typing import Callable, Tuple

import torch

from comfy_api.v3 import io


def do_nothing(x: torch.Tensor, mode: str = None):
    return x


def mps_gather_workaround(input, dim, index):
    if input.shape[-1] == 1:
        return torch.gather(
            input.unsqueeze(-1),
            dim - 1 if dim < 0 else dim,
            index.unsqueeze(-1)
        ).squeeze(-1)
    return torch.gather(input, dim, index)


def bipartite_soft_matching_random2d(
    metric: torch.Tensor, w: int, h: int, sx: int, sy: int, r: int, no_rand: bool = False
) -> Tuple[Callable, Callable]:
    """
    Partitions the tokens into src and dst and merges r tokens from src to dst.
    Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.

    Args:
        - metric [B, N, C]: metric to use for similarity
        - w: image width in tokens
        - h: image height in tokens
        - sx: stride in the x dimension for dst, must divide w
        - sy: stride in the y dimension for dst, must divide h
        - r: number of tokens to remove (by merging)
        - no_rand: if true, disable randomness (use top left corner only)
    """
    B, N, _ = metric.shape

    if r <= 0 or w == 1 or h == 1:
        return do_nothing, do_nothing

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        hsy, wsx = h // sy, w // sx

        # For each sy by sx kernel, randomly assign one token to be dst and the rest src
        if no_rand:
            rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
        else:
            rand_idx = torch.randint(sy * sx, size=(hsy, wsx, 1), device=metric.device)

        # The image might not divide evenly by sx and sy, so work on a view of the top left of the idx buffer instead
        idx_buffer_view = torch.zeros(hsy, wsx, sy * sx, device=metric.device, dtype=torch.int64)
        idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
        idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)

        # Image is not divisible by sx or sy, so move it into a new buffer
        if (hsy * sy) < h or (wsx * sx) < w:
            idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
        else:
            idx_buffer = idx_buffer_view

        # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)

        # We're finished with these
        del idx_buffer, idx_buffer_view

        # rand_idx is currently dst|src, so split them
        num_dst = hsy * wsx
        a_idx = rand_idx[:, num_dst:, :]  # src
        b_idx = rand_idx[:, :num_dst, :]  # dst

        def split(x):
            C = x.shape[-1]
            src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
            return src, dst

        # Cosine similarity between A and B
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], r)

        # Find the most similar greedily
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = split(x)
        n, t1, c = src.shape

        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        _, _, c = unm.shape

        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))

        # Combine back to the original shape
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)

        return out

    return merge, unmerge


def get_functions(x, ratio, original_shape):
    b, c, original_h, original_w = original_shape
    original_tokens = original_h * original_w
    downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
    stride_x = 2
    stride_y = 2
    max_downsample = 1

    if downsample <= max_downsample:
        w = int(math.ceil(original_w / downsample))
        h = int(math.ceil(original_h / downsample))
        r = int(x.shape[1] * ratio)
        no_rand = False
        m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
        return m, u

    def nothing(y):
        return y

    return nothing, nothing


class TomePatchModel(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TomePatchModel_V3",
            category="model_patches/unet",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("ratio", default=0.3, min=0.0, max=1.0, step=0.01),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, ratio):
        u = None

        def tomesd_m(q, k, v, extra_options):
            nonlocal u
            # NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q,
            # however from my basic testing it seems that using q instead gives better results.
            m, u = get_functions(q, ratio, extra_options["original_shape"])
            return m(q), k, v

        def tomesd_u(n, extra_options):
            return u(n)

        m = model.clone()
        m.set_model_attn1_patch(tomesd_m)
        m.set_model_attn1_output_patch(tomesd_u)
        return io.NodeOutput(m)


NODES_LIST = [
    TomePatchModel,
]
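Reviewer note (not part of the diff): the merge/unmerge pair returned by bipartite_soft_matching_random2d is meant to bracket the attention computation, shrinking the token count on the way in and restoring it on the way out. A quick sanity-check sketch, assuming the definitions above and invented sizes:

import torch

B, w, h, C = 1, 8, 8, 4
tokens = torch.randn(B, w * h, C)     # (B, N, C) token grid, 64 tokens

merge, unmerge = bipartite_soft_matching_random2d(tokens, w=w, h=h, sx=2, sy=2, r=16)
reduced = merge(tokens)               # (B, N - r, C): 48 tokens after merging
restored = unmerge(reduced)           # (B, N, C): back to 64 tokens
assert restored.shape == tokens.shape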
32 comfy_extras/v3/nodes_torch_compile.py Normal file
@@ -0,0 +1,32 @@
from __future__ import annotations

from comfy_api.torch_helpers import set_torch_compile_wrapper
from comfy_api.v3 import io


class TorchCompileModel(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TorchCompileModel_V3",
            category="_for_testing",
            is_experimental=True,
            inputs=[
                io.Model.Input("model"),
                io.Combo.Input("backend", options=["inductor", "cudagraphs"]),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, backend):
        m = model.clone()
        set_torch_compile_wrapper(model=m, backend=backend)
        return io.NodeOutput(m)


NODES_LIST = [
    TorchCompileModel,
]
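Reviewer note (not part of the diff): the two backend options correspond to torch.compile backends. A minimal standalone sketch of what that argument controls (the compiled function here is invented):

import torch

def f(x):
    return torch.sin(x) ** 2 + torch.cos(x) ** 2

compiled = torch.compile(f, backend="inductor")  # or backend="cudagraphs"
print(compiled(torch.randn(8)))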
658 comfy_extras/v3/nodes_train.py Normal file
@@ -0,0 +1,658 @@
from __future__ import annotations

import logging
import os

import numpy as np
import safetensors
import torch
import torch.utils.checkpoint
import tqdm
from PIL import Image, ImageDraw, ImageFont

import comfy.model_management
import comfy.samplers
import comfy.sd
import comfy.utils
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.weight_adapter import adapters
from comfy_api.v3 import io, ui


def make_batch_extra_option_dict(d, indicies, full_size=None):
    new_dict = {}
    for k, v in d.items():
        newv = v
        if isinstance(v, dict):
            newv = make_batch_extra_option_dict(v, indicies, full_size=full_size)
        elif isinstance(v, torch.Tensor):
            if full_size is None or v.size(0) == full_size:
                newv = v[indicies]
        elif isinstance(v, (list, tuple)) and len(v) == full_size:
            newv = [v[i] for i in indicies]
        new_dict[k] = newv
    return new_dict
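Reviewer note (not part of the diff): make_batch_extra_option_dict recursively slices any full-size tensor or list in extra_args down to the rows picked for the current mini-batch, passing everything else through unchanged. A small sketch with invented contents:

import torch

extra_args = {
    "cond_scale": 1.0,                               # scalar: passed through
    "model_options": {"sigmas": torch.arange(4.0)},  # full-size tensor: sliced
}
batch = make_batch_extra_option_dict(extra_args, [0, 2], full_size=4)
print(batch["model_options"]["sigmas"])              # tensor([0., 2.])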
class TrainSampler(comfy.samplers.Sampler):

    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.loss_callback = loss_callback
        self.batch_size = batch_size
        self.total_steps = total_steps
        self.seed = seed
        self.training_dtype = training_dtype

    def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
        cond = model_wrap.conds["positive"]
        dataset_size = sigmas.size(0)
        torch.cuda.empty_cache()
        for i in (pbar := tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
            noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
            indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()

            batch_latent = torch.stack([latent_image[i] for i in indicies])
            batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
            batch_sigmas = [
                model_wrap.inner_model.model_sampling.percent_to_sigma(
                    torch.rand((1,)).item()
                ) for _ in range(min(self.batch_size, dataset_size))
            ]
            batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)

            xt = model_wrap.inner_model.model_sampling.noise_scaling(
                batch_sigmas,
                batch_noise,
                batch_latent,
                False
            )
            x0 = model_wrap.inner_model.model_sampling.noise_scaling(
                torch.zeros_like(batch_sigmas),
                torch.zeros_like(batch_noise),
                batch_latent,
                False
            )

            model_wrap.conds["positive"] = [
                cond[i] for i in indicies
            ]
            batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)

            with torch.autocast(xt.device.type, dtype=self.training_dtype):
                x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
                loss = self.loss_fn(x0_pred, x0)
            loss.backward()
            if self.loss_callback:
                self.loss_callback(loss.item())
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

            self.optimizer.step()
            self.optimizer.zero_grad()
        torch.cuda.empty_cache()
        return torch.zeros_like(latent_image)
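Reviewer note (not part of the diff): TrainSampler repurposes the sampler interface to run an optimization loop. Stripped of the ComfyUI plumbing, each step amounts to roughly this (toy model and sizes invented):

import torch

toy_denoiser = torch.nn.Linear(8, 8)        # stand-in for the wrapped model
opt = torch.optim.AdamW(toy_denoiser.parameters(), lr=5e-4)
loss_fn = torch.nn.MSELoss()

x0 = torch.randn(2, 8)                      # clean latents (training targets)
sigma = torch.rand(2, 1)                    # random noise levels for the batch
xt = x0 + sigma * torch.randn_like(x0)      # noised inputs at those levels

loss = loss_fn(toy_denoiser(xt), x0)        # predict x0 and compare to clean
loss.backward()
opt.step()
opt.zero_grad()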
class BiasDiff(torch.nn.Module):
    def __init__(self, bias):
        super().__init__()
        self.bias = bias

    def __call__(self, b):
        org_dtype = b.dtype
        return (b.to(self.bias) + self.bias).to(org_dtype)

    def passive_memory_usage(self):
        return self.bias.nelement() * self.bias.element_size()

    def move_to(self, device):
        self.to(device=device)
        return self.passive_memory_usage()
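Reviewer note (not part of the diff): BiasDiff holds a trainable delta that is added on top of an existing tensor whenever the wrapper is applied, casting back to the original dtype. A tiny sketch, assuming the class above:

import torch

bias = torch.nn.Parameter(torch.zeros(4))   # trainable fp32 offset
wrapper = BiasDiff(bias)

original = torch.ones(4, dtype=torch.float16)
patched = wrapper(original)                 # (original + bias), back in fp16
print(patched.dtype, wrapper.passive_memory_usage())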
def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None):
    """Utility function to load and process a list of images.

    Args:
        image_files: List of image filenames
        input_dir: Base directory containing the images
        resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")

    Returns:
        torch.Tensor: Batch of processed images
    """
    if not image_files:
        raise ValueError("No valid images found in input")

    output_images = []

    for file in image_files:
        image_path = os.path.join(input_dir, file)
        img = node_helpers.pillow(Image.open, image_path)

        if img.mode == "I":
            img = img.point(lambda i: i * (1 / 255))
        img = img.convert("RGB")

        if w is None and h is None:
            w, h = img.size[0], img.size[1]

        # Resize image to first image
        if img.size[0] != w or img.size[1] != h:
            if resize_method == "Stretch":
                img = img.resize((w, h), Image.Resampling.LANCZOS)
            elif resize_method == "Crop":
                img = img.crop((0, 0, w, h))
            elif resize_method == "Pad":
                img = img.resize((w, h), Image.Resampling.LANCZOS)
            elif resize_method == "None":
                raise ValueError(
                    "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
                )

        img_array = np.array(img).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_array)[None,]
        output_images.append(img_tensor)

    return torch.cat(output_images, dim=0)
def draw_loss_graph(loss_map, steps):
    width, height = 500, 300
    img = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(img)

    min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
    scaled_loss = [(l_v - min_loss) / (max_loss - min_loss) for l_v in loss_map.values()]

    prev_point = (0, height - int(scaled_loss[0] * height))
    for i, l_v in enumerate(scaled_loss[1:], start=1):
        x = int(i / (steps - 1) * width)
        y = height - int(l_v * height)
        draw.line([prev_point, (x, y)], fill="blue", width=2)
        prev_point = (x, y)

    return img
def find_all_highest_child_module_with_forward(model: torch.nn.Module, result=None, name=None):
    if result is None:
        result = []
    elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
        result.append(model)
        logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
        return result
    name = name or "root"
    for next_name, child in model.named_children():
        find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}")
    return result


def patch(m):
    if not hasattr(m, "forward"):
        return
    org_forward = m.forward
    def fwd(args, kwargs):
        return org_forward(*args, **kwargs)
    def checkpointing_fwd(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(
            fwd, args, kwargs, use_reentrant=False
        )
    m.org_forward = org_forward
    m.forward = checkpointing_fwd


def unpatch(m):
    if hasattr(m, "org_forward"):
        m.forward = m.org_forward
        del m.org_forward
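Reviewer note (not part of the diff): patch() reroutes a module's forward through torch.utils.checkpoint so that activations are recomputed during backward instead of stored, and unpatch() restores the original. A self-contained sketch on an invented layer:

import torch

layer = torch.nn.Linear(16, 16)
patch(layer)                                          # forward now checkpoints
y = layer(torch.randn(4, 16, requires_grad=True))
y.sum().backward()                                    # forward is recomputed here
unpatch(layer)                                        # original forward restored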
class LoadImageSetFromFolderNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoadImageSetFromFolderNode_V3",
            display_name="Load Image Dataset from Folder _V3",
            category="loaders",
            description="Loads a batch of images from a directory for training.",
            is_experimental=True,
            inputs=[
                io.Combo.Input(
                    "folder", options=folder_paths.get_input_subfolders(), tooltip="The folder to load images from."
                ),
                io.Combo.Input(
                    "resize_method", options=["None", "Stretch", "Crop", "Pad"], default="None", optional=True
                ),
            ],
            outputs=[
                io.Image.Output(),
            ],
        )

    @classmethod
    def execute(cls, folder, resize_method="None"):
        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
        valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
        image_files = [
            f
            for f in os.listdir(sub_input_dir)
            if any(f.lower().endswith(ext) for ext in valid_extensions)
        ]
        return io.NodeOutput(load_and_process_images(image_files, sub_input_dir, resize_method))
class LoadImageTextSetFromFolderNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoadImageTextSetFromFolderNode_V3",
            display_name="Load Image and Text Dataset from Folder _V3",
            category="loaders",
            description="Loads a batch of images and captions from a directory for training.",
            is_experimental=True,
            inputs=[
                io.Combo.Input("folder", options=folder_paths.get_input_subfolders(), tooltip="The folder to load images from."),
                io.Clip.Input("clip", tooltip="The CLIP model used for encoding the text."),
                io.Combo.Input("resize_method", options=["None", "Stretch", "Crop", "Pad"], default="None", optional=True),
                io.Int.Input("width", default=-1, min=-1, max=10000, step=1, tooltip="The width to resize the images to. -1 means use the original width.", optional=True),
                io.Int.Input("height", default=-1, min=-1, max=10000, step=1, tooltip="The height to resize the images to. -1 means use the original height.", optional=True),
            ],
            outputs=[
                io.Image.Output(),
                io.Conditioning.Output(),
            ],
        )

    @classmethod
    def execute(cls, folder, clip, resize_method="None", width=None, height=None):
        if clip is None:
            raise RuntimeError(
                "ERROR: clip input is invalid: None\n\n"
                "If the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model."
            )

        logging.info(f"Loading images from folder: {folder}")

        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
        valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]

        image_files = []
        for item in os.listdir(sub_input_dir):
            path = os.path.join(sub_input_dir, item)
            if any(item.lower().endswith(ext) for ext in valid_extensions):
                image_files.append(path)
            elif os.path.isdir(path):
                # Support kohya-ss/sd-scripts folder structure
                repeat = 1
                if item.split("_")[0].isdigit():
                    repeat = int(item.split("_")[0])
                image_files.extend([
                    os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions)
                ] * repeat)

        caption_file_path = [
            f.replace(os.path.splitext(f)[1], ".txt")
            for f in image_files
        ]
        captions = []
        for caption_file in caption_file_path:
            caption_path = os.path.join(sub_input_dir, caption_file)
            if os.path.exists(caption_path):
                with open(caption_path, "r", encoding="utf-8") as f:
                    caption = f.read().strip()
                    captions.append(caption)
            else:
                captions.append("")

        width = width if width != -1 else None
        height = height if height != -1 else None
        output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height)

        logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")

        logging.info(f"Encoding captions from {sub_input_dir}.")
        conditions = []
        empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
        for text in captions:
            if text == "":
                conditions.append(empty_cond)
                continue
            tokens = clip.tokenize(text)
            conditions.extend(clip.encode_from_tokens_scheduled(tokens))
        logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.")
        return io.NodeOutput(output_tensor, conditions)
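Reviewer note (not part of the diff): the subdirectory branch above follows the kohya-ss/sd-scripts convention, where a folder named "<repeats>_<name>" has its images counted that many times and captions live beside the images as .txt files. An invented example layout:

# input/my_dataset/
# ├── 001.png
# ├── 001.txt          # caption for 001.png
# └── 3_closeups/      # images inside are repeated 3 times
#     ├── 002.png
#     └── 002.txt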
class LoraModelLoader(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoraModelLoader_V3",
            display_name="Load LoRA Model _V3",
            category="loaders",
            description="Load Trained LoRA weights from Train LoRA node.",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The diffusion model the LoRA will be applied to."),
                io.LoraModel.Input("lora", tooltip="The LoRA model to apply to the diffusion model."),
                io.Float.Input("strength_model", default=1.0, min=-100.0, max=100.0, step=0.01, tooltip="How strongly to modify the diffusion model. This value can be negative."),
            ],
            outputs=[
                io.Model.Output(tooltip="The modified diffusion model."),
            ],
        )

    @classmethod
    def execute(cls, model, lora, strength_model):
        if strength_model == 0:
            return io.NodeOutput(model)

        model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
        return io.NodeOutput(model_lora)
class LossGraphNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LossGraphNode_V3",
            display_name="Plot Loss Graph _V3",
            category="training",
            description="Plots the loss graph and saves it to the output directory.",
            is_experimental=True,
            is_output_node=True,
            inputs=[
                io.LossMap.Input("loss"),  # TODO: original V1 node has also `default={}` parameter
                io.String.Input("filename_prefix", default="loss_graph"),
            ],
            outputs=[],
            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
        )

    @classmethod
    def execute(cls, loss, filename_prefix):
        loss_values = loss["loss"]
        width, height = 800, 480
        margin = 40

        img = Image.new(
            "RGB", (width + margin, height + margin), "white"
        )  # Extend canvas
        draw = ImageDraw.Draw(img)

        min_loss, max_loss = min(loss_values), max(loss_values)
        scaled_loss = [(l_v - min_loss) / (max_loss - min_loss) for l_v in loss_values]

        steps = len(loss_values)

        prev_point = (margin, height - int(scaled_loss[0] * height))
        for i, l_v in enumerate(scaled_loss[1:], start=1):
            x = margin + int(i / steps * width)  # Scale X properly
            y = height - int(l_v * height)
            draw.line([prev_point, (x, y)], fill="blue", width=2)
            prev_point = (x, y)

        draw.line([(margin, 0), (margin, height)], fill="black", width=2)  # Y-axis
        draw.line(
            [(margin, height), (width + margin, height)], fill="black", width=2
        )  # X-axis

        try:
            font = ImageFont.truetype("arial.ttf", 12)
        except IOError:
            font = ImageFont.load_default()

        # Add axis labels
        draw.text((5, height // 2), "Loss", font=font, fill="black")
        draw.text((width // 2, height + 10), "Steps", font=font, fill="black")

        # Add min/max loss values
        draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
        draw.text(
            (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
        )
        return io.NodeOutput(ui=ui.PreviewImage(img, cls=cls))
class SaveLoRA(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SaveLoRA_V3",
            display_name="Save LoRA Weights _V3",
            category="loaders",
            is_experimental=True,
            is_output_node=True,
            inputs=[
                io.LoraModel.Input("lora", tooltip="The LoRA model to save. Do not use the model with LoRA layers."),
                io.String.Input("prefix", default="loras/ComfyUI_trained_lora", tooltip="The prefix to use for the saved LoRA file."),
                io.Int.Input("steps", tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.", optional=True),
            ],
            outputs=[],
        )

    @classmethod
    def execute(cls, lora, prefix, steps=None):
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
            prefix, folder_paths.get_output_directory()
        )
        if steps is None:
            output_checkpoint = f"{filename}_{counter:05}_.safetensors"
        else:
            output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
        safetensors.torch.save_file(lora, output_checkpoint)
        return io.NodeOutput()
class TrainLoraNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TrainLoraNode_V3",
            display_name="Train LoRA _V3",
            category="training",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The model to train the LoRA on."),
                io.Latent.Input("latents", tooltip="The Latents to use for training, serve as dataset/input of the model."),
                io.Conditioning.Input("positive", tooltip="The positive conditioning to use for training."),
                io.Int.Input("batch_size", default=1, min=1, max=10000, step=1, tooltip="The batch size to use for training."),
                io.Int.Input("steps", default=16, min=1, max=100000, tooltip="The number of steps to train the LoRA for."),
                io.Float.Input("learning_rate", default=0.0005, min=0.0000001, max=1.0, step=0.000001, tooltip="The learning rate to use for training."),
                io.Int.Input("rank", default=8, min=1, max=128, tooltip="The rank of the LoRA layers."),
                io.Combo.Input("optimizer", options=["AdamW", "Adam", "SGD", "RMSprop"], default="AdamW", tooltip="The optimizer to use for training."),
                io.Combo.Input("loss_function", options=["MSE", "L1", "Huber", "SmoothL1"], default="MSE", tooltip="The loss function to use for training."),
                io.Int.Input("seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)"),
                io.Combo.Input("training_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for training."),
                io.Combo.Input("lora_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for lora."),
                io.Combo.Input("existing_lora", options=folder_paths.get_filename_list("loras") + ["[None]"], default="[None]", tooltip="The existing LoRA to append to. Set to None for new LoRA."),
            ],
            outputs=[
                io.Model.Output(display_name="model_with_lora"),
                io.LoraModel.Output(display_name="lora"),
                io.LossMap.Output(display_name="loss"),
                io.Int.Output(display_name="steps"),
            ],
        )

    @classmethod
    def execute(
        cls,
        model,
        latents,
        positive,
        batch_size,
        steps,
        learning_rate,
        rank,
        optimizer,
        loss_function,
        seed,
        training_dtype,
        lora_dtype,
        existing_lora,
    ):
        mp = model.clone()
        dtype = node_helpers.string_to_torch_dtype(training_dtype)
        lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
        mp.set_model_compute_dtype(dtype)

        latents = latents["samples"].to(dtype)
        num_images = latents.shape[0]
        logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
        if len(positive) == 1 and num_images > 1:
            positive = positive * num_images
        elif len(positive) != num_images:
            raise ValueError(
                f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
            )

        with torch.inference_mode(False):
            lora_sd = {}
            generator = torch.Generator()
            generator.manual_seed(seed)

            # Load existing LoRA weights if provided
            existing_weights = {}
            existing_steps = 0
            if existing_lora != "[None]":
                lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
                # Extract steps from filename like "trained_lora_10_steps_20250225_203716"
                existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
                if lora_path:
                    existing_weights = comfy.utils.load_torch_file(lora_path)

            all_weight_adapters = []
            for n, m in mp.model.named_modules():
                if hasattr(m, "weight_function"):
                    if m.weight is not None:
                        key = "{}.weight".format(n)
                        shape = m.weight.shape
                        if len(shape) >= 2:
                            alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
                            dora_scale = existing_weights.get(
                                f"{key}.dora_scale", None
                            )
                            for adapter_cls in adapters:
                                existing_adapter = adapter_cls.load(
                                    n, existing_weights, alpha, dora_scale
                                )
                                if existing_adapter is not None:
                                    break
                            else:
                                # If no existing adapter found, use LoRA
                                # We will add algo option in the future
                                existing_adapter = None
                                adapter_cls = adapters[0]

                            if existing_adapter is not None:
                                train_adapter = existing_adapter.to_train().to(lora_dtype)
                            else:
                                # Use LoRA with alpha=1.0 by default
                                train_adapter = adapter_cls.create_train(
                                    m.weight, rank=rank, alpha=1.0
                                ).to(lora_dtype)
                            for name, parameter in train_adapter.named_parameters():
                                lora_sd[f"{n}.{name}"] = parameter

                            mp.add_weight_wrapper(key, train_adapter)
                            all_weight_adapters.append(train_adapter)
                        else:
                            diff = torch.nn.Parameter(
                                torch.zeros(
                                    m.weight.shape, dtype=lora_dtype, requires_grad=True
                                )
                            )
                            diff_module = BiasDiff(diff)
                            mp.add_weight_wrapper(key, BiasDiff(diff))
                            all_weight_adapters.append(diff_module)
                            lora_sd["{}.diff".format(n)] = diff
                    if hasattr(m, "bias") and m.bias is not None:
                        key = "{}.bias".format(n)
                        bias = torch.nn.Parameter(
                            torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True)
                        )
                        bias_module = BiasDiff(bias)
                        lora_sd["{}.diff_b".format(n)] = bias
                        mp.add_weight_wrapper(key, BiasDiff(bias))
                        all_weight_adapters.append(bias_module)

            if optimizer == "Adam":
                optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
            elif optimizer == "AdamW":
                optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
            elif optimizer == "SGD":
                optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
            elif optimizer == "RMSprop":
                optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)

            # Setup loss function based on selection
            if loss_function == "MSE":
                criterion = torch.nn.MSELoss()
            elif loss_function == "L1":
                criterion = torch.nn.L1Loss()
            elif loss_function == "Huber":
                criterion = torch.nn.HuberLoss()
            elif loss_function == "SmoothL1":
                criterion = torch.nn.SmoothL1Loss()

            # setup models
            for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
                patch(m)
            mp.model.requires_grad_(False)
            comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)

            # Setup sampler and guider like in test script
            loss_map = {"loss": []}
            def loss_callback(loss):
                loss_map["loss"].append(loss)
            train_sampler = TrainSampler(
                criterion,
                optimizer,
                loss_callback=loss_callback,
                batch_size=batch_size,
                total_steps=steps,
                seed=seed,
                training_dtype=dtype
            )
            guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
            guider.set_conds(positive)  # Set conditioning from input

            # Training loop
            try:
                # Generate dummy sigmas and noise
                sigmas = torch.tensor(range(num_images))
                noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
                guider.sample(
                    noise.generate_noise({"samples": latents}),
                    latents,
                    train_sampler,
                    sigmas,
                    seed=noise.seed
                )
            finally:
                for m in mp.model.modules():
                    unpatch(m)
                del train_sampler, optimizer

            for adapter in all_weight_adapters:
                adapter.requires_grad_(False)

            for param in lora_sd:
                lora_sd[param] = lora_sd[param].to(lora_dtype)

            return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)


NODES_LIST = [
    LoadImageSetFromFolderNode,
    LoadImageTextSetFromFolderNode,
    LoraModelLoader,
    LossGraphNode,
    SaveLoRA,
    TrainLoraNode,
]
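Reviewer note (not part of the diff): the adapters trained here ultimately represent a low-rank update on top of each frozen weight. The arithmetic, on invented dimensions:

import torch

d_out, d_in, r = 32, 32, 8
W = torch.randn(d_out, d_in)        # frozen base weight
A = torch.randn(r, d_in) * 0.01     # trainable down-projection
B = torch.zeros(d_out, r)           # trainable up-projection, zero init

x = torch.randn(4, d_in)
y = x @ (W + B @ A).t()             # effective weight is W + B A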
4 nodes.py
@@ -2350,6 +2350,10 @@ def init_builtin_extra_nodes():
         "v3/nodes_sdupscale.py",
         "v3/nodes_slg.py",
         "v3/nodes_stable_cascade.py",
+        "v3/nodes_tcfg.py",
+        "v3/nodes_tomesd.py",
+        "v3/nodes_torch_compile.py",
+        "v3/nodes_train.py",
         "v3/nodes_upscale_model.py",
         "v3/nodes_video.py",
         "v3/nodes_video_model.py",