converted nodes files starting with "t" letter

This commit is contained in:
bigcat88 2025-07-24 15:22:35 +03:00
parent 487ec28b9c
commit 2ea2bc2941
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
5 changed files with 954 additions and 0 deletions

View File

@ -0,0 +1,70 @@
"""TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)"""
from __future__ import annotations
import torch
from comfy_api.v3 import io
def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
"""Drop tangential components from uncond score to align with cond score."""
# (B, 1, ...)
batch_num = cond_score.shape[0]
cond_score_flat = cond_score.reshape(batch_num, 1, -1).float()
uncond_score_flat = uncond_score.reshape(batch_num, 1, -1).float()
# Score matrix A (B, 2, ...)
score_matrix = torch.cat((uncond_score_flat, cond_score_flat), dim=1)
try:
_, _, Vh = torch.linalg.svd(score_matrix, full_matrices=False)
except RuntimeError:
# Fallback to CPU
_, _, Vh = torch.linalg.svd(score_matrix.cpu(), full_matrices=False)
# Drop the tangential components
v1 = Vh[:, 0:1, :].to(uncond_score_flat.device) # (B, 1, ...)
uncond_score_td = (uncond_score_flat @ v1.transpose(-2, -1)) * v1
return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)
class TCFG(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TCFG_V3",
display_name="Tangential Damping CFG _V3",
category="advanced/guidance",
description="TCFG Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality.",
inputs=[
io.Model.Input("model"),
],
outputs=[
io.Model.Output(display_name="patched_model"),
],
)
@classmethod
def execute(cls, model):
m = model.clone()
def tangential_damping_cfg(args):
# Assume [cond, uncond, ...]
x = args["input"]
conds_out = args["conds_out"]
if len(conds_out) <= 1 or None in args["conds"][:2]:
# Skip when either cond or uncond is None
return conds_out
cond_pred = conds_out[0]
uncond_pred = conds_out[1]
uncond_td = score_tangential_damping(x - cond_pred, x - uncond_pred)
uncond_pred_td = x - uncond_td
return [cond_pred, uncond_pred_td] + conds_out[2:]
m.set_model_sampler_pre_cfg_function(tangential_damping_cfg)
return io.NodeOutput(m)
NODES_LIST = [
TCFG,
]

View File

@ -0,0 +1,190 @@
"""Taken from: https://github.com/dbolya/tomesd"""
from __future__ import annotations
import math
from typing import Callable, Tuple
import torch
from comfy_api.v3 import io
def do_nothing(x: torch.Tensor, mode:str=None):
return x
def mps_gather_workaround(input, dim, index):
if input.shape[-1] == 1:
return torch.gather(
input.unsqueeze(-1),
dim - 1 if dim < 0 else dim,
index.unsqueeze(-1)
).squeeze(-1)
return torch.gather(input, dim, index)
def bipartite_soft_matching_random2d(
metric: torch.Tensor,w: int, h: int, sx: int, sy: int, r: int, no_rand: bool = False
) -> Tuple[Callable, Callable]:
"""
Partitions the tokens into src and dst and merges r tokens from src to dst.
Dst tokens are partitioned by choosing one randomy in each (sx, sy) region.
Args:
- metric [B, N, C]: metric to use for similarity
- w: image width in tokens
- h: image height in tokens
- sx: stride in the x dimension for dst, must divide w
- sy: stride in the y dimension for dst, must divide h
- r: number of tokens to remove (by merging)
- no_rand: if true, disable randomness (use top left corner only)
"""
B, N, _ = metric.shape
if r <= 0 or w == 1 or h == 1:
return do_nothing, do_nothing
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
with torch.no_grad():
hsy, wsx = h // sy, w // sx
# For each sy by sx kernel, randomly assign one token to be dst and the rest src
if no_rand:
rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
else:
rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)
# The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead
idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)
# Image is not divisible by sx or sy so we need to move it into a new buffer
if (hsy * sy) < h or (wsx * sx) < w:
idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
else:
idx_buffer = idx_buffer_view
# We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)
# We're finished with these
del idx_buffer, idx_buffer_view
# rand_idx is currently dst|src, so split them
num_dst = hsy * wsx
a_idx = rand_idx[:, num_dst:, :] # src
b_idx = rand_idx[:, :num_dst, :] # dst
def split(x):
C = x.shape[-1]
src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
return src, dst
# Cosine similarity between A and B
metric = metric / metric.norm(dim=-1, keepdim=True)
a, b = split(metric)
scores = a @ b.transpose(-1, -2)
# Can't reduce more than the # tokens in src
r = min(a.shape[1], r)
# Find the most similar greedily
node_max, node_idx = scores.max(dim=-1)
edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
unm_idx = edge_idx[..., r:, :] # Unmerged Tokens
src_idx = edge_idx[..., :r, :] # Merged Tokens
dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)
def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
src, dst = split(x)
n, t1, c = src.shape
unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
return torch.cat([unm, dst], dim=1)
def unmerge(x: torch.Tensor) -> torch.Tensor:
unm_len = unm_idx.shape[1]
unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
_, _, c = unm.shape
src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))
# Combine back to the original shape
out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)
return out
return merge, unmerge
def get_functions(x, ratio, original_shape):
b, c, original_h, original_w = original_shape
original_tokens = original_h * original_w
downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
stride_x = 2
stride_y = 2
max_downsample = 1
if downsample <= max_downsample:
w = int(math.ceil(original_w / downsample))
h = int(math.ceil(original_h / downsample))
r = int(x.shape[1] * ratio)
no_rand = False
m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
return m, u
def nothing(y):
return y
return nothing, nothing
class TomePatchModel(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TomePatchModel_V3",
category="model_patches/unet",
inputs=[
io.Model.Input("model"),
io.Float.Input("ratio", default=0.3, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, ratio):
u = None
def tomesd_m(q, k, v, extra_options):
nonlocal u
#NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q
#however from my basic testing it seems that using q instead gives better results
m, u = get_functions(q, ratio, extra_options["original_shape"])
return m(q), k, v
def tomesd_u(n, extra_options):
return u(n)
m = model.clone()
m.set_model_attn1_patch(tomesd_m)
m.set_model_attn1_output_patch(tomesd_u)
return io.NodeOutput(m)
NODES_LIST = [
TomePatchModel,
]

View File

@ -0,0 +1,32 @@
from __future__ import annotations
from comfy_api.torch_helpers import set_torch_compile_wrapper
from comfy_api.v3 import io
class TorchCompileModel(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TorchCompileModel_V3",
category="_for_testing",
is_experimental=True,
inputs=[
io.Model.Input("model"),
io.Combo.Input("backend", options=["inductor", "cudagraphs"]),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, backend):
m = model.clone()
set_torch_compile_wrapper(model=m, backend=backend)
return io.NodeOutput(m)
NODES_LIST = [
TorchCompileModel,
]

View File

@ -0,0 +1,658 @@
from __future__ import annotations
import logging
import os
import numpy as np
import safetensors
import torch
import torch.utils.checkpoint
import tqdm
from PIL import Image, ImageDraw, ImageFont
import comfy.model_management
import comfy.samplers
import comfy.sd
import comfy.utils
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.weight_adapter import adapters
from comfy_api.v3 import io, ui
def make_batch_extra_option_dict(d, indicies, full_size=None):
new_dict = {}
for k, v in d.items():
newv = v
if isinstance(v, dict):
newv = make_batch_extra_option_dict(v, indicies, full_size=full_size)
elif isinstance(v, torch.Tensor):
if full_size is None or v.size(0) == full_size:
newv = v[indicies]
elif isinstance(v, (list, tuple)) and len(v) == full_size:
newv = [v[i] for i in indicies]
new_dict[k] = newv
return new_dict
class TrainSampler(comfy.samplers.Sampler):
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
self.loss_fn = loss_fn
self.optimizer = optimizer
self.loss_callback = loss_callback
self.batch_size = batch_size
self.total_steps = total_steps
self.seed = seed
self.training_dtype = training_dtype
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
cond = model_wrap.conds["positive"]
dataset_size = sigmas.size(0)
torch.cuda.empty_cache()
for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()
batch_latent = torch.stack([latent_image[i] for i in indicies])
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
batch_sigmas = [
model_wrap.inner_model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
) for _ in range(min(self.batch_size, dataset_size))
]
batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
xt = model_wrap.inner_model.model_sampling.noise_scaling(
batch_sigmas,
batch_noise,
batch_latent,
False
)
x0 = model_wrap.inner_model.model_sampling.noise_scaling(
torch.zeros_like(batch_sigmas),
torch.zeros_like(batch_noise),
batch_latent,
False
)
model_wrap.conds["positive"] = [
cond[i] for i in indicies
]
batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)
with torch.autocast(xt.device.type, dtype=self.training_dtype):
x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
loss = self.loss_fn(x0_pred, x0)
loss.backward()
if self.loss_callback:
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
self.optimizer.step()
self.optimizer.zero_grad()
torch.cuda.empty_cache()
return torch.zeros_like(latent_image)
class BiasDiff(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.bias = bias
def __call__(self, b):
org_dtype = b.dtype
return (b.to(self.bias) + self.bias).to(org_dtype)
def passive_memory_usage(self):
return self.bias.nelement() * self.bias.element_size()
def move_to(self, device):
self.to(device=device)
return self.passive_memory_usage()
def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None):
"""Utility function to load and process a list of images.
Args:
image_files: List of image filenames
input_dir: Base directory containing the images
resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")
Returns:
torch.Tensor: Batch of processed images
"""
if not image_files:
raise ValueError("No valid images found in input")
output_images = []
for file in image_files:
image_path = os.path.join(input_dir, file)
img = node_helpers.pillow(Image.open, image_path)
if img.mode == "I":
img = img.point(lambda i: i * (1 / 255))
img = img.convert("RGB")
if w is None and h is None:
w, h = img.size[0], img.size[1]
# Resize image to first image
if img.size[0] != w or img.size[1] != h:
if resize_method == "Stretch":
img = img.resize((w, h), Image.Resampling.LANCZOS)
elif resize_method == "Crop":
img = img.crop((0, 0, w, h))
elif resize_method == "Pad":
img = img.resize((w, h), Image.Resampling.LANCZOS)
elif resize_method == "None":
raise ValueError(
"Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
)
img_array = np.array(img).astype(np.float32) / 255.0
img_tensor = torch.from_numpy(img_array)[None,]
output_images.append(img_tensor)
return torch.cat(output_images, dim=0)
def draw_loss_graph(loss_map, steps):
width, height = 500, 300
img = Image.new("RGB", (width, height), "white")
draw = ImageDraw.Draw(img)
min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
scaled_loss = [(l_v - min_loss) / (max_loss - min_loss) for l_v in loss_map.values()]
prev_point = (0, height - int(scaled_loss[0] * height))
for i, l_v in enumerate(scaled_loss[1:], start=1):
x = int(i / (steps - 1) * width)
y = height - int(l_v * height)
draw.line([prev_point, (x, y)], fill="blue", width=2)
prev_point = (x, y)
return img
def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None):
if result is None:
result = []
elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
result.append(model)
logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
return result
name = name or "root"
for next_name, child in model.named_children():
find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}")
return result
def patch(m):
if not hasattr(m, "forward"):
return
org_forward = m.forward
def fwd(args, kwargs):
return org_forward(*args, **kwargs)
def checkpointing_fwd(*args, **kwargs):
return torch.utils.checkpoint.checkpoint(
fwd, args, kwargs, use_reentrant=False
)
m.org_forward = org_forward
m.forward = checkpointing_fwd
def unpatch(m):
if hasattr(m, "org_forward"):
m.forward = m.org_forward
del m.org_forward
class LoadImageSetFromFolderNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoadImageSetFromFolderNode_V3",
display_name="Load Image Dataset from Folder _V3",
category="loaders",
description="Loads a batch of images from a directory for training.",
is_experimental=True,
inputs=[
io.Combo.Input(
"folder", options=folder_paths.get_input_subfolders(), tooltip="The folder to load images from."
),
io.Combo.Input(
"resize_method", options=["None", "Stretch", "Crop", "Pad"], default="None", optional=True
),
],
outputs=[
io.Image.Output(),
],
)
@classmethod
def execute(cls, folder, resize_method="None"):
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
image_files = [
f
for f in os.listdir(sub_input_dir)
if any(f.lower().endswith(ext) for ext in valid_extensions)
]
return io.NodeOutput(load_and_process_images(image_files, sub_input_dir, resize_method))
class LoadImageTextSetFromFolderNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoadImageTextSetFromFolderNode_V3",
display_name="Load Image and Text Dataset from Folder _V3",
category="loaders",
description="Loads a batch of images and caption from a directory for training.",
is_experimental=True,
inputs=[
io.Combo.Input("folder", options=folder_paths.get_input_subfolders(), tooltip="The folder to load images from."),
io.Clip.Input("clip", tooltip="The CLIP model used for encoding the text."),
io.Combo.Input("resize_method", options=["None", "Stretch", "Crop", "Pad"], default="None", optional=True),
io.Int.Input("width", default=-1, min=-1, max=10000, step=1, tooltip="The width to resize the images to. -1 means use the original width.", optional=True),
io.Int.Input("height", default=-1, min=-1, max=10000, step=1, tooltip="The height to resize the images to. -1 means use the original height.", optional=True),
],
outputs=[
io.Image.Output(),
io.Conditioning.Output(),
],
)
@classmethod
def execute(cls, folder, clip, resize_method="None", width=None, height=None):
if clip is None:
raise RuntimeError(
"ERROR: clip input is invalid: None\n\n"
"If the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model."
)
logging.info(f"Loading images from folder: {folder}")
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
image_files = []
for item in os.listdir(sub_input_dir):
path = os.path.join(sub_input_dir, item)
if any(item.lower().endswith(ext) for ext in valid_extensions):
image_files.append(path)
elif os.path.isdir(path):
# Support kohya-ss/sd-scripts folder structure
repeat = 1
if item.split("_")[0].isdigit():
repeat = int(item.split("_")[0])
image_files.extend([
os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions)
] * repeat)
caption_file_path = [
f.replace(os.path.splitext(f)[1], ".txt")
for f in image_files
]
captions = []
for caption_file in caption_file_path:
caption_path = os.path.join(sub_input_dir, caption_file)
if os.path.exists(caption_path):
with open(caption_path, "r", encoding="utf-8") as f:
caption = f.read().strip()
captions.append(caption)
else:
captions.append("")
width = width if width != -1 else None
height = height if height != -1 else None
output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height)
logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")
logging.info(f"Encoding captions from {sub_input_dir}.")
conditions = []
empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
for text in captions:
if text == "":
conditions.append(empty_cond)
tokens = clip.tokenize(text)
conditions.extend(clip.encode_from_tokens_scheduled(tokens))
logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.")
return io.NodeOutput(output_tensor, conditions)
class LoraModelLoader(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoraModelLoader_V3",
display_name="Load LoRA Model _V3",
category="loaders",
description="Load Trained LoRA weights from Train LoRA node.",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The diffusion model the LoRA will be applied to."),
io.LoraModel.Input("lora", tooltip="The LoRA model to apply to the diffusion model."),
io.Float.Input("strength_model", default=1.0, min=-100.0, max=100.0, step=0.01, tooltip="How strongly to modify the diffusion model. This value can be negative."),
],
outputs=[
io.Model.Output(tooltip="The modified diffusion model."),
],
)
@classmethod
def execute(cls, model, lora, strength_model):
if strength_model == 0:
return io.NodeOutput(model)
model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
return io.NodeOutput(model_lora)
class LossGraphNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LossGraphNode_V3",
display_name="Plot Loss Graph _V3",
category="training",
description="Plots the loss graph and saves it to the output directory.",
is_experimental=True,
is_output_node=True,
inputs=[
io.LossMap.Input("loss"), # TODO: original V1 node has also `default={}` parameter
io.String.Input("filename_prefix", default="loss_graph"),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
)
@classmethod
def execute(cls, loss, filename_prefix):
loss_values = loss["loss"]
width, height = 800, 480
margin = 40
img = Image.new(
"RGB", (width + margin, height + margin), "white"
) # Extend canvas
draw = ImageDraw.Draw(img)
min_loss, max_loss = min(loss_values), max(loss_values)
scaled_loss = [(l_v - min_loss) / (max_loss - min_loss) for l_v in loss_values]
steps = len(loss_values)
prev_point = (margin, height - int(scaled_loss[0] * height))
for i, l_v in enumerate(scaled_loss[1:], start=1):
x = margin + int(i / steps * width) # Scale X properly
y = height - int(l_v * height)
draw.line([prev_point, (x, y)], fill="blue", width=2)
prev_point = (x, y)
draw.line([(margin, 0), (margin, height)], fill="black", width=2) # Y-axis
draw.line(
[(margin, height), (width + margin, height)], fill="black", width=2
) # X-axis
try:
font = ImageFont.truetype("arial.ttf", 12)
except IOError:
font = ImageFont.load_default()
# Add axis labels
draw.text((5, height // 2), "Loss", font=font, fill="black")
draw.text((width // 2, height + 10), "Steps", font=font, fill="black")
# Add min/max loss values
draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
draw.text(
(margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
)
return io.NodeOutput(ui=ui.PreviewImage(img, cls=cls))
class SaveLoRA(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SaveLoRA_V3",
display_name="Save LoRA Weights _V3",
category="loaders",
is_experimental=True,
is_output_node=True,
inputs=[
io.LoraModel.Input("lora", tooltip="The LoRA model to save. Do not use the model with LoRA layers."),
io.String.Input("prefix", default="loras/ComfyUI_trained_lora", tooltip="The prefix to use for the saved LoRA file."),
io.Int.Input("steps", tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.", optional=True),
],
outputs=[],
)
@classmethod
def execute(cls, lora, prefix, steps=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
prefix, folder_paths.get_output_directory()
)
if steps is None:
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
else:
output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
safetensors.torch.save_file(lora, output_checkpoint)
return io.NodeOutput()
class TrainLoraNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TrainLoraNode_V3",
display_name="Train LoRA _V3",
category="training",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to train the LoRA on."),
io.Latent.Input("latents", tooltip="The Latents to use for training, serve as dataset/input of the model."),
io.Conditioning.Input("positive", tooltip="The positive conditioning to use for training."),
io.Int.Input("batch_size", default=1, min=1, max=10000, step=1, tooltip="The batch size to use for training."),
io.Int.Input("steps", default=16, min=1, max=100000, tooltip="The number of steps to train the LoRA for."),
io.Float.Input("learning_rate", default=0.0005, min=0.0000001, max=1.0, step=0.000001, tooltip="The learning rate to use for training."),
io.Int.Input("rank", default=8, min=1, max=128, tooltip="The rank of the LoRA layers."),
io.Combo.Input("optimizer", options=["AdamW", "Adam", "SGD", "RMSprop"], default="AdamW", tooltip="The optimizer to use for training."),
io.Combo.Input("loss_function", options=["MSE", "L1", "Huber", "SmoothL1"], default="MSE", tooltip="The loss function to use for training."),
io.Int.Input("seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)"),
io.Combo.Input("training_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for training."),
io.Combo.Input("lora_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for lora."),
io.Combo.Input("existing_lora", options=folder_paths.get_filename_list("loras") + ["[None]"], default="[None]", tooltip="The existing LoRA to append to. Set to None for new LoRA."),
],
outputs=[
io.Model.Output(display_name="model_with_lora"),
io.LoraModel.Output(display_name="lora"),
io.LossMap.Output(display_name="loss"),
io.Int.Output(display_name="steps"),
],
)
@classmethod
def execute(
cls,
model,
latents,
positive,
batch_size,
steps,
learning_rate,
rank,
optimizer,
loss_function,
seed,
training_dtype,
lora_dtype,
existing_lora,
):
mp = model.clone()
dtype = node_helpers.string_to_torch_dtype(training_dtype)
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
latents = latents["samples"].to(dtype)
num_images = latents.shape[0]
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
positive = positive * num_images
elif len(positive) != num_images:
raise ValueError(
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
)
with torch.inference_mode(False):
lora_sd = {}
generator = torch.Generator()
generator.manual_seed(seed)
# Load existing LoRA weights if provided
existing_weights = {}
existing_steps = 0
if existing_lora != "[None]":
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
if lora_path:
existing_weights = comfy.utils.load_torch_file(lora_path)
all_weight_adapters = []
for n, m in mp.model.named_modules():
if hasattr(m, "weight_function"):
if m.weight is not None:
key = "{}.weight".format(n)
shape = m.weight.shape
if len(shape) >= 2:
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
dora_scale = existing_weights.get(
f"{key}.dora_scale", None
)
for adapter_cls in adapters:
existing_adapter = adapter_cls.load(
n, existing_weights, alpha, dora_scale
)
if existing_adapter is not None:
break
else:
# If no existing adapter found, use LoRA
# We will add algo option in the future
existing_adapter = None
adapter_cls = adapters[0]
if existing_adapter is not None:
train_adapter = existing_adapter.to_train().to(lora_dtype)
else:
# Use LoRA with alpha=1.0 by default
train_adapter = adapter_cls.create_train(
m.weight, rank=rank, alpha=1.0
).to(lora_dtype)
for name, parameter in train_adapter.named_parameters():
lora_sd[f"{n}.{name}"] = parameter
mp.add_weight_wrapper(key, train_adapter)
all_weight_adapters.append(train_adapter)
else:
diff = torch.nn.Parameter(
torch.zeros(
m.weight.shape, dtype=lora_dtype, requires_grad=True
)
)
diff_module = BiasDiff(diff)
mp.add_weight_wrapper(key, BiasDiff(diff))
all_weight_adapters.append(diff_module)
lora_sd["{}.diff".format(n)] = diff
if hasattr(m, "bias") and m.bias is not None:
key = "{}.bias".format(n)
bias = torch.nn.Parameter(
torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True)
)
bias_module = BiasDiff(bias)
lora_sd["{}.diff_b".format(n)] = bias
mp.add_weight_wrapper(key, BiasDiff(bias))
all_weight_adapters.append(bias_module)
if optimizer == "Adam":
optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
elif optimizer == "AdamW":
optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
elif optimizer == "SGD":
optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
elif optimizer == "RMSprop":
optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
# Setup loss function based on selection
if loss_function == "MSE":
criterion = torch.nn.MSELoss()
elif loss_function == "L1":
criterion = torch.nn.L1Loss()
elif loss_function == "Huber":
criterion = torch.nn.HuberLoss()
elif loss_function == "SmoothL1":
criterion = torch.nn.SmoothL1Loss()
# setup models
for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
patch(m)
mp.model.requires_grad_(False)
comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
# Setup sampler and guider like in test script
loss_map = {"loss": []}
def loss_callback(loss):
loss_map["loss"].append(loss)
train_sampler = TrainSampler(
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
total_steps=steps,
seed=seed,
training_dtype=dtype
)
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
guider.set_conds(positive) # Set conditioning from input
# Training loop
try:
# Generate dummy sigmas and noise
sigmas = torch.tensor(range(num_images))
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
seed=noise.seed
)
finally:
for m in mp.model.modules():
unpatch(m)
del train_sampler, optimizer
for adapter in all_weight_adapters:
adapter.requires_grad_(False)
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype)
return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
NODES_LIST = [
LoadImageSetFromFolderNode,
LoadImageTextSetFromFolderNode,
LoraModelLoader,
LossGraphNode,
SaveLoRA,
TrainLoraNode,
]

View File

@ -2350,6 +2350,10 @@ def init_builtin_extra_nodes():
"v3/nodes_sdupscale.py",
"v3/nodes_slg.py",
"v3/nodes_stable_cascade.py",
"v3/nodes_tcfg.py",
"v3/nodes_tomesd.py",
"v3/nodes_torch_compile.py",
"v3/nodes_train.py",
"v3/nodes_upscale_model.py",
"v3/nodes_video.py",
"v3/nodes_video_model.py",