LoRA Trainer: LoRA training node in weight adapter scheme (#8446)

This commit is contained in:
Kohaku-Blueleaf 2025-06-14 07:25:59 +08:00 committed by GitHub
parent 5bf69bde35
commit 520eb77b72
12 changed files with 949 additions and 24 deletions

View File

@@ -37,6 +37,8 @@ class IO(StrEnum):
    CONTROL_NET = "CONTROL_NET"
    VAE = "VAE"
    MODEL = "MODEL"
+    LORA_MODEL = "LORA_MODEL"
+    LOSS_MAP = "LOSS_MAP"
    CLIP_VISION = "CLIP_VISION"
    CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
    STYLE_MODEL = "STYLE_MODEL"

View File

@@ -753,7 +753,7 @@ class BasicTransformerBlock(nn.Module):
            for p in patch:
                n = p(n, extra_options)
-        x += n
+        x = n + x
        if "middle_patch" in transformer_patches:
            patch = transformer_patches["middle_patch"]
            for p in patch:
@@ -793,12 +793,12 @@ class BasicTransformerBlock(nn.Module):
            for p in patch:
                n = p(n, extra_options)
-        x += n
+        x = n + x
        if self.is_res:
            x_skip = x
        x = self.ff(self.norm3(x))
        if self.is_res:
-            x += x_skip
+            x = x_skip + x
        return x
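Editor's note: both hunks make the same fix, replacing in-place accumulation (`x += n`, `x += x_skip`) with out-of-place forms. This matters once LoRA training is in play, because autograd rejects in-place mutation of tensors it must track. A minimal sketch of the failure mode in stock PyTorch (names illustrative):

import torch

w = torch.randn(3, requires_grad=True)
n = torch.randn(3)

# In-place mutation of a leaf tensor that requires grad is rejected by autograd:
try:
    w += n
except RuntimeError as e:
    print(e)  # "a leaf Variable that requires grad is being used in an in-place operation."

# The out-of-place form builds a new graph node instead of mutating the leaf:
x = n + w
x.sum().backward()
print(w.grad)  # ones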

View File

@@ -17,23 +17,26 @@
"""
from __future__ import annotations
-from typing import Optional, Callable
-import torch
+
+import collections
import copy
import inspect
import logging
-import uuid
-import collections
import math
+import uuid
+from typing import Callable, Optional
+
+import torch
+
-import comfy.utils
import comfy.float
-import comfy.model_management
-import comfy.lora
import comfy.hooks
+import comfy.lora
+import comfy.model_management
import comfy.patcher_extension
-from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
+import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
+from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
def string_to_seed(data):
crc = 0xFFFFFFFF

View File

@@ -1081,7 +1081,28 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
    return (model_patcher, clip, vae, clipvision)

-def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
+def load_diffusion_model_state_dict(sd, model_options={}):
+    """
+    Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
+
+    Args:
+        sd (dict): State dictionary containing model weights and configuration
+        model_options (dict, optional): Additional options for model loading. Supports:
+            - dtype: Override model data type
+            - custom_operations: Custom model operations
+            - fp8_optimizations: Enable FP8 optimizations
+
+    Returns:
+        ModelPatcher: A wrapped model instance that handles device management and weight loading.
+        Returns None if the model configuration cannot be detected.
+
+    The function:
+    1. Detects and handles different model formats (regular, diffusers, mmdit)
+    2. Configures model dtype based on parameters and device capabilities
+    3. Handles weight conversion and device placement
+    4. Manages model optimization settings
+    5. Loads weights and returns a device-managed model instance
+    """
dtype = model_options.get("dtype", None)
#Allow loading unets from checkpoint files
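Editor's note: a hedged usage sketch of the documented behavior; the checkpoint path and dtype override below are illustrative, not part of the commit:

import torch
import comfy.sd
import comfy.utils

sd = comfy.utils.load_torch_file("models/unet/example_unet.safetensors")  # illustrative path
patcher = comfy.sd.load_diffusion_model_state_dict(sd, model_options={"dtype": torch.float16})
if patcher is None:
    raise RuntimeError("Could not detect model config from state dict")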

View File

@@ -1,4 +1,4 @@
-from .base import WeightAdapterBase
+from .base import WeightAdapterBase, WeightAdapterTrainBase
from .lora import LoRAAdapter
from .loha import LoHaAdapter
from .lokr import LoKrAdapter
@@ -15,3 +15,9 @@ adapters: list[type[WeightAdapterBase]] = [
    OFTAdapter,
    BOFTAdapter,
]
+
+__all__ = [
+    "WeightAdapterBase",
+    "WeightAdapterTrainBase",
+    "adapters"
+] + [a.__name__ for a in adapters]
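Editor's note: with the dynamic `__all__`, the base classes and every registered adapter are importable from the package root. A small sketch of what the registry exposes:

from comfy.weight_adapter import LoRAAdapter, WeightAdapterTrainBase, adapters

# The registry drives format detection: the first adapter whose load()
# recognizes a key wins (see the training node later in this commit).
assert LoRAAdapter in adapters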

View File

@@ -12,12 +12,20 @@ class WeightAdapterBase:
    weights: list[torch.Tensor]

    @classmethod
-    def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]:
+    def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]:
raise NotImplementedError
def to_train(self) -> "WeightAdapterTrainBase":
raise NotImplementedError
@classmethod
def create_train(cls, weight, *args) -> "WeightAdapterTrainBase":
"""
weight: The original weight tensor to be modified.
*args: Additional arguments for configuration, such as rank, alpha etc.
"""
raise NotImplementedError
def calculate_weight(
self,
weight,
@@ -33,10 +41,22 @@ class WeightAdapterBase:
class WeightAdapterTrainBase(nn.Module):
# We follow the scheme of PR #7032
def __init__(self):
super().__init__()
# [TODO] Collaborate with LoRA training PR #7032
def __call__(self, w):
"""
w: The original weight tensor to be modified.
"""
raise NotImplementedError
def passive_memory_usage(self):
raise NotImplementedError("passive_memory_usage is not implemented")
def move_to(self, device):
self.to(device)
return self.passive_memory_usage()
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
@@ -102,3 +122,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
padded_tensor[new_slices] = tensor[orig_slices]
return padded_tensor
def tucker_weight_from_conv(up, down, mid):
up = up.reshape(up.size(0), up.size(1))
down = down.reshape(down.size(0), down.size(1))
return torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down)
def tucker_weight(wa, wb, t):
temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
return torch.einsum("i j ..., i r -> r j ...", temp, wa)
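Editor's note: a quick shape sanity-check for the Tucker helpers, assuming the functions above are in scope; the dimensions are illustrative:

import torch

rank, out_dim, in_dim, k = 4, 8, 3, 3
up = torch.randn(out_dim, rank, 1, 1)    # conv-style LoRA up factor
down = torch.randn(rank, in_dim, 1, 1)   # conv-style LoRA down factor
mid = torch.randn(rank, rank, k, k)      # Tucker core carrying the spatial kernel

w = tucker_weight_from_conv(up, down, mid)
assert w.shape == (out_dim, in_dim, k, k)  # full conv weight reconstructed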

View File

@@ -3,7 +3,56 @@ from typing import Optional
import torch
import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
+from .base import (
+    WeightAdapterBase,
+    WeightAdapterTrainBase,
+    weight_decompose,
+    pad_tensor_to_shape,
+    tucker_weight_from_conv,
+)
class LoraDiff(WeightAdapterTrainBase):
def __init__(self, weights):
super().__init__()
mat1, mat2, alpha, mid, dora_scale, reshape = weights
out_dim, rank = mat1.shape[0], mat1.shape[1]
rank, in_dim = mat2.shape[0], mat2.shape[1]
if mid is not None:
convdim = mid.ndim - 2
layer = (
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.Conv3d
)[convdim]
else:
layer = torch.nn.Linear
self.lora_up = layer(rank, out_dim, bias=False)
self.lora_down = layer(in_dim, rank, bias=False)
self.lora_up.weight.data.copy_(mat1)
self.lora_down.weight.data.copy_(mat2)
        if mid is not None:
            # mid is the Tucker core; its weight shape is (rank, rank, *kernel_size)
            self.lora_mid = layer(rank, rank, mid.shape[2:], bias=False)
            self.lora_mid.weight.data.copy_(mid)
        else:
            self.lora_mid = None
self.rank = rank
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
def __call__(self, w):
org_dtype = w.dtype
if self.lora_mid is None:
diff = self.lora_up.weight @ self.lora_down.weight
else:
diff = tucker_weight_from_conv(
self.lora_up.weight, self.lora_down.weight, self.lora_mid.weight
)
scale = self.alpha / self.rank
weight = w + scale * diff.reshape(w.shape)
return weight.to(org_dtype)
def passive_memory_usage(self):
return sum(param.numel() * param.element_size() for param in self.parameters())
class LoRAAdapter(WeightAdapterBase):
@@ -13,6 +62,21 @@ class LoRAAdapter(WeightAdapterBase):
self.loaded_keys = loaded_keys
self.weights = weights
@classmethod
def create_train(cls, weight, rank=1, alpha=1.0):
out_dim = weight.shape[0]
in_dim = weight.shape[1:].numel()
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
torch.nn.init.constant_(mat2, 0.0)
return LoraDiff(
(mat1, mat2, alpha, None, None, None)
)
def to_train(self):
return LoraDiff(self.weights)
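Editor's note: a sanity sketch of the initialization scheme, assuming the package imports above. `mat2` (`lora_down`) is zero-initialized, so a freshly created adapter leaves the weight untouched — the standard LoRA init (random A, zero B):

import torch
from comfy.weight_adapter import LoRAAdapter

w = torch.randn(16, 32)
adapter = LoRAAdapter.create_train(w, rank=4, alpha=1.0)

# up @ down is zero at init, so training starts from the unmodified weight
assert torch.allclose(adapter(w), w)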
@classmethod
def load(
cls,

comfy_extras/nodes_train.py (new file, 705 lines)
View File

@@ -0,0 +1,705 @@
import datetime
import json
import logging
import os
import numpy as np
import safetensors
import torch
from PIL import Image, ImageDraw, ImageFont, ImageOps
from PIL.PngImagePlugin import PngInfo
import torch.utils.checkpoint
import tqdm
import comfy.samplers
import comfy.sd
import comfy.utils
import comfy.model_management
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.cli_args import args
from comfy.comfy_types.node_typing import IO
from comfy.weight_adapter import adapters
class TrainSampler(comfy.samplers.Sampler):
def __init__(self, loss_fn, optimizer, loss_callback=None):
self.loss_fn = loss_fn
self.optimizer = optimizer
self.loss_callback = loss_callback
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
self.optimizer.zero_grad()
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False)
latent = model_wrap.inner_model.model_sampling.noise_scaling(
torch.zeros_like(sigmas),
torch.zeros_like(noise, requires_grad=True),
latent_image,
False
)
        # x0 prediction; assumes the model was loaded with gradient tracking enabled
denoised = model_wrap(noise, sigmas, **extra_args)
        try:
            loss = self.loss_fn(denoised, latent.clone())
        except RuntimeError as e:
            if "does not require grad and does not have a grad_fn" in str(e):
                logging.error("WARNING: This is likely because the model was loaded in inference mode.")
            # Without a computed loss there is nothing to backpropagate, so re-raise.
            raise
loss.backward()
if self.loss_callback:
self.loss_callback(loss.item())
self.optimizer.step()
return torch.zeros_like(latent_image)
class BiasDiff(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.bias = bias
def __call__(self, b):
org_dtype = b.dtype
return (b.to(self.bias) + self.bias).to(org_dtype)
def passive_memory_usage(self):
return self.bias.nelement() * self.bias.element_size()
def move_to(self, device):
self.to(device=device)
return self.passive_memory_usage()
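Editor's note: a minimal sketch of BiasDiff in isolation (tensors illustrative): the wrapped original stays frozen and only the zero-initialized diff parameter receives gradients:

import torch

bias = torch.zeros(8)                    # frozen original bias
diff = torch.nn.Parameter(torch.zeros(8))
wrapper = BiasDiff(diff)

out = wrapper(bias)                      # bias + diff
out.sum().backward()
print(diff.grad)                         # ones: the gradient reaches the diff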
def load_and_process_images(image_files, input_dir, resize_method="None"):
"""Utility function to load and process a list of images.
Args:
image_files: List of image filenames
input_dir: Base directory containing the images
resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")
Returns:
torch.Tensor: Batch of processed images
"""
if not image_files:
raise ValueError("No valid images found in input")
output_images = []
w, h = None, None
for file in image_files:
image_path = os.path.join(input_dir, file)
img = node_helpers.pillow(Image.open, image_path)
if img.mode == "I":
img = img.point(lambda i: i * (1 / 255))
img = img.convert("RGB")
if w is None and h is None:
w, h = img.size[0], img.size[1]
# Resize image to first image
if img.size[0] != w or img.size[1] != h:
if resize_method == "Stretch":
img = img.resize((w, h), Image.Resampling.LANCZOS)
elif resize_method == "Crop":
img = img.crop((0, 0, w, h))
elif resize_method == "Pad":
img = img.resize((w, h), Image.Resampling.LANCZOS)
elif resize_method == "None":
raise ValueError(
"Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
)
img_array = np.array(img).astype(np.float32) / 255.0
img_tensor = torch.from_numpy(img_array)[None,]
output_images.append(img_tensor)
return torch.cat(output_images, dim=0)
class LoadImageSetNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"images": (
[
f
for f in os.listdir(folder_paths.get_input_directory())
if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"))
],
{"image_upload": True, "allow_batch": True},
)
},
"optional": {
"resize_method": (
["None", "Stretch", "Crop", "Pad"],
{"default": "None"},
),
},
}
INPUT_IS_LIST = True
RETURN_TYPES = ("IMAGE",)
FUNCTION = "load_images"
CATEGORY = "loaders"
EXPERIMENTAL = True
DESCRIPTION = "Loads a batch of images from a directory for training."
@classmethod
def VALIDATE_INPUTS(s, images, resize_method):
filenames = images[0] if isinstance(images[0], list) else images
for image in filenames:
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
return True
    def load_images(self, images, resize_method):
        # INPUT_IS_LIST is set, so each input arrives wrapped in a list;
        # the parameter name must match the "images" key declared in INPUT_TYPES
        input_files = images[0] if isinstance(images[0], list) else images
        resize_method = resize_method[0]
        input_dir = folder_paths.get_input_directory()
valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"]
image_files = [
f
for f in input_files
if any(f.lower().endswith(ext) for ext in valid_extensions)
]
output_tensor = load_and_process_images(image_files, input_dir, resize_method)
return (output_tensor,)
class LoadImageSetFromFolderNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."})
},
"optional": {
"resize_method": (
["None", "Stretch", "Crop", "Pad"],
{"default": "None"},
),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "load_images"
CATEGORY = "loaders"
EXPERIMENTAL = True
DESCRIPTION = "Loads a batch of images from a directory for training."
def load_images(self, folder, resize_method):
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
image_files = [
f
for f in os.listdir(sub_input_dir)
if any(f.lower().endswith(ext) for ext in valid_extensions)
]
output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method)
return (output_tensor,)
def draw_loss_graph(loss_map, steps):
width, height = 500, 300
img = Image.new("RGB", (width, height), "white")
draw = ImageDraw.Draw(img)
    min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
    loss_range = (max_loss - min_loss) or 1.0  # guard against a flat loss curve
    scaled_loss = [(l - min_loss) / loss_range for l in loss_map.values()]
prev_point = (0, height - int(scaled_loss[0] * height))
for i, l in enumerate(scaled_loss[1:], start=1):
x = int(i / (steps - 1) * width)
y = height - int(l * height)
draw.line([prev_point, (x, y)], fill="blue", width=2)
prev_point = (x, y)
return img
def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None):
if result is None:
result = []
elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
result.append(model)
logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
return result
name = name or "root"
for next_name, child in model.named_children():
find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}")
return result
def patch(m):
if not hasattr(m, "forward"):
return
org_forward = m.forward
def fwd(args, kwargs):
return org_forward(*args, **kwargs)
def checkpointing_fwd(*args, **kwargs):
return torch.utils.checkpoint.checkpoint(
fwd, args, kwargs, use_reentrant=False
)
m.org_forward = org_forward
m.forward = checkpointing_fwd
def unpatch(m):
if hasattr(m, "org_forward"):
m.forward = m.org_forward
del m.org_forward
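Editor's note: a hedged round-trip sketch of the checkpointing patch on a toy model (the model itself is illustrative, not from the commit):

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
for m in find_all_highest_child_module_with_forward(model):
    patch(m)

x = torch.randn(2, 8, requires_grad=True)
model(x).sum().backward()  # activations inside patched modules are recomputed in backward

for m in model.modules():
    unpatch(m)  # restores the original forward methods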
class TrainLoraNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}),
"latents": (
"LATENT",
{
"tooltip": "The Latents to use for training, serve as dataset/input of the model."
},
),
"positive": (
IO.CONDITIONING,
{"tooltip": "The positive conditioning to use for training."},
),
"batch_size": (
IO.INT,
{
"default": 1,
"min": 1,
"max": 10000,
"step": 1,
"tooltip": "The batch size to use for training.",
},
),
"steps": (
IO.INT,
{
"default": 16,
"min": 1,
"max": 100000,
"tooltip": "The number of steps to train the LoRA for.",
},
),
"learning_rate": (
IO.FLOAT,
{
"default": 0.0005,
"min": 0.0000001,
"max": 1.0,
"step": 0.000001,
"tooltip": "The learning rate to use for training.",
},
),
"rank": (
IO.INT,
{
"default": 8,
"min": 1,
"max": 128,
"tooltip": "The rank of the LoRA layers.",
},
),
"optimizer": (
["AdamW", "Adam", "SGD", "RMSprop"],
{
"default": "AdamW",
"tooltip": "The optimizer to use for training.",
},
),
"loss_function": (
["MSE", "L1", "Huber", "SmoothL1"],
{
"default": "MSE",
"tooltip": "The loss function to use for training.",
},
),
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)",
},
),
"training_dtype": (
["bf16", "fp32"],
{"default": "bf16", "tooltip": "The dtype to use for training."},
),
"lora_dtype": (
["bf16", "fp32"],
{"default": "bf16", "tooltip": "The dtype to use for lora."},
),
"existing_lora": (
folder_paths.get_filename_list("loras") + ["[None]"],
{
"default": "[None]",
"tooltip": "The existing LoRA to append to. Set to None for new LoRA.",
},
),
},
}
RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT)
RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps")
FUNCTION = "train"
CATEGORY = "training"
EXPERIMENTAL = True
def train(
self,
model,
latents,
positive,
batch_size,
steps,
learning_rate,
rank,
optimizer,
loss_function,
seed,
training_dtype,
lora_dtype,
existing_lora,
):
mp = model.clone()
dtype = node_helpers.string_to_torch_dtype(training_dtype)
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
latents = latents["samples"].to(dtype)
num_images = latents.shape[0]
with torch.inference_mode(False):
lora_sd = {}
generator = torch.Generator()
generator.manual_seed(seed)
# Load existing LoRA weights if provided
existing_weights = {}
existing_steps = 0
if existing_lora != "[None]":
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
                # Extract steps from a filename like "trained_lora_10_steps_20250225_203716"
                if "_steps_" in existing_lora:
                    existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
if lora_path:
existing_weights = comfy.utils.load_torch_file(lora_path)
all_weight_adapters = []
for n, m in mp.model.named_modules():
if hasattr(m, "weight_function"):
if m.weight is not None:
key = "{}.weight".format(n)
shape = m.weight.shape
if len(shape) >= 2:
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
dora_scale = existing_weights.get(
f"{key}.dora_scale", None
)
for adapter_cls in adapters:
existing_adapter = adapter_cls.load(
n, existing_weights, alpha, dora_scale
)
if existing_adapter is not None:
break
else:
# If no existing adapter found, use LoRA
# We will add algo option in the future
existing_adapter = None
adapter_cls = adapters[0]
if existing_adapter is not None:
train_adapter = existing_adapter.to_train().to(lora_dtype)
else:
# Use LoRA with alpha=1.0 by default
train_adapter = adapter_cls.create_train(
m.weight, rank=rank, alpha=1.0
).to(lora_dtype)
for name, parameter in train_adapter.named_parameters():
lora_sd[f"{n}.{name}"] = parameter
mp.add_weight_wrapper(key, train_adapter)
all_weight_adapters.append(train_adapter)
else:
diff = torch.nn.Parameter(
torch.zeros(
m.weight.shape, dtype=lora_dtype, requires_grad=True
)
)
diff_module = BiasDiff(diff)
mp.add_weight_wrapper(key, BiasDiff(diff))
all_weight_adapters.append(diff_module)
lora_sd["{}.diff".format(n)] = diff
if hasattr(m, "bias") and m.bias is not None:
key = "{}.bias".format(n)
bias = torch.nn.Parameter(
torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True)
)
bias_module = BiasDiff(bias)
lora_sd["{}.diff_b".format(n)] = bias
mp.add_weight_wrapper(key, BiasDiff(bias))
all_weight_adapters.append(bias_module)
if optimizer == "Adam":
optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
elif optimizer == "AdamW":
optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
elif optimizer == "SGD":
optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
elif optimizer == "RMSprop":
optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
# Setup loss function based on selection
if loss_function == "MSE":
criterion = torch.nn.MSELoss()
elif loss_function == "L1":
criterion = torch.nn.L1Loss()
elif loss_function == "Huber":
criterion = torch.nn.HuberLoss()
elif loss_function == "SmoothL1":
criterion = torch.nn.SmoothL1Loss()
# setup models
for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
patch(m)
comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
# Setup sampler and guider like in test script
loss_map = {"loss": []}
def loss_callback(loss):
loss_map["loss"].append(loss)
pbar.set_postfix({"loss": f"{loss:.4f}"})
train_sampler = TrainSampler(
criterion, optimizer, loss_callback=loss_callback
)
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
guider.set_conds(positive) # Set conditioning from input
ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced()
            # NOTE: dataset loading currently resizes all images to the first image's size
# Training loop
torch.cuda.empty_cache()
try:
for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
# Generate random sigma
sigma = mp.model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
)
sigma = torch.tensor([sigma])
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)
indices = torch.randperm(num_images)[:batch_size]
ss.sample(
noise, guider, train_sampler, sigma, {"samples": latents[indices].clone()}
)
finally:
for m in mp.model.modules():
unpatch(m)
del ss, train_sampler, optimizer
torch.cuda.empty_cache()
for adapter in all_weight_adapters:
adapter.requires_grad_(False)
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype)
return (mp, lora_sd, loss_map, steps + existing_steps)
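Editor's note: conceptually, `add_weight_wrapper` registers the adapter so the patcher computes `adapter(weight)` on the fly — the base weight stays frozen and only the adapter's parameters receive gradients. A simplified sketch outside the ModelPatcher machinery (this is not the actual call path):

import torch
from comfy.weight_adapter import LoRAAdapter

base = torch.randn(16, 32)                 # frozen model weight
adapter = LoRAAdapter.create_train(base, rank=4, alpha=1.0)

effective = adapter(base)                  # w + (alpha / rank) * up @ down
effective.sum().backward()

print(base.requires_grad)                             # False: the base weight never trains
print(adapter.lora_down.weight.grad.abs().sum() > 0)  # True: gradients hit the adapter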
class LoraModelLoader:
def __init__(self):
self.loaded_lora = None
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
"lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to apply to the diffusion model."}),
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
}
}
RETURN_TYPES = ("MODEL",)
OUTPUT_TOOLTIPS = ("The modified diffusion model.",)
FUNCTION = "load_lora_model"
CATEGORY = "loaders"
DESCRIPTION = "Load Trained LoRA weights from Train LoRA node."
EXPERIMENTAL = True
def load_lora_model(self, model, lora, strength_model):
if strength_model == 0:
return (model, )
model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
return (model_lora, )
class SaveLoRA:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"lora": (
IO.LORA_MODEL,
{
"tooltip": "The LoRA model to save. Do not use the model with LoRA layers."
},
),
"prefix": (
"STRING",
{
"default": "trained_lora",
"tooltip": "The prefix to use for the saved LoRA file.",
},
),
},
"optional": {
"steps": (
IO.INT,
{
"forceInput": True,
"tooltip": "Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
},
),
},
}
RETURN_TYPES = ()
FUNCTION = "save"
CATEGORY = "loaders"
EXPERIMENTAL = True
OUTPUT_NODE = True
def save(self, lora, prefix, steps=None):
date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
if steps is None:
output_file = f"models/loras/{prefix}_{date}_lora.safetensors"
else:
output_file = f"models/loras/{prefix}_{steps}_steps_{date}_lora.safetensors"
safetensors.torch.save_file(lora, output_file)
return {}
class LossGraphNode:
def __init__(self):
self.output_dir = folder_paths.get_temp_directory()
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"loss": (IO.LOSS_MAP, {"default": {}}),
"filename_prefix": (IO.STRING, {"default": "loss_graph"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "plot_loss"
OUTPUT_NODE = True
CATEGORY = "training"
EXPERIMENTAL = True
DESCRIPTION = "Plots the loss graph and saves it to the output directory."
def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None):
loss_values = loss["loss"]
width, height = 800, 480
margin = 40
img = Image.new(
"RGB", (width + margin, height + margin), "white"
) # Extend canvas
draw = ImageDraw.Draw(img)
        min_loss, max_loss = min(loss_values), max(loss_values)
        loss_range = (max_loss - min_loss) or 1.0  # guard against a flat loss curve
        scaled_loss = [(l - min_loss) / loss_range for l in loss_values]
steps = len(loss_values)
prev_point = (margin, height - int(scaled_loss[0] * height))
for i, l in enumerate(scaled_loss[1:], start=1):
x = margin + int(i / steps * width) # Scale X properly
y = height - int(l * height)
draw.line([prev_point, (x, y)], fill="blue", width=2)
prev_point = (x, y)
draw.line([(margin, 0), (margin, height)], fill="black", width=2) # Y-axis
draw.line(
[(margin, height), (width + margin, height)], fill="black", width=2
) # X-axis
font = None
try:
font = ImageFont.truetype("arial.ttf", 12)
except IOError:
font = ImageFont.load_default()
# Add axis labels
draw.text((5, height // 2), "Loss", font=font, fill="black")
draw.text((width // 2, height + 10), "Steps", font=font, fill="black")
# Add min/max loss values
draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
draw.text(
(margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
)
metadata = None
if not args.disable_metadata:
metadata = PngInfo()
if prompt is not None:
metadata.add_text("prompt", json.dumps(prompt))
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
img.save(
os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"),
pnginfo=metadata,
)
return {
"ui": {
"images": [
{
"filename": f"{filename_prefix}_{date}.png",
"subfolder": "",
"type": "temp",
}
]
}
}
NODE_CLASS_MAPPINGS = {
"TrainLoraNode": TrainLoraNode,
"SaveLoRANode": SaveLoRA,
"LoraModelLoader": LoraModelLoader,
"LoadImageSetFromFolderNode": LoadImageSetFromFolderNode,
"LossGraphNode": LossGraphNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"TrainLoraNode": "Train LoRA",
"SaveLoRANode": "Save LoRA Weights",
"LoraModelLoader": "Load LoRA Model",
"LoadImageSetFromFolderNode": "Load Image Dataset from Folder",
"LossGraphNode": "Plot Loss Graph",
}

View File

@@ -1,23 +1,35 @@
-import sys
import copy
-import logging
-import threading
import heapq
+import inspect
+import logging
+import sys
+import threading
import time
import traceback
from enum import Enum
-import inspect
from typing import List, Literal, NamedTuple, Optional

import torch
-import nodes
import comfy.model_management
-from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
-from comfy_execution.graph_utils import is_link, GraphBuilder
-from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
+import nodes
+from comfy_execution.caching import (
+    CacheKeySetID,
+    CacheKeySetInputSignature,
+    DependencyAwareCache,
+    HierarchicalCache,
+    LRUCache,
+)
+from comfy_execution.graph import (
+    DynamicPrompt,
+    ExecutionBlocker,
+    ExecutionList,
+    get_input_info,
+)
+from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input
class ExecutionResult(Enum):
SUCCESS = 0
FAILURE = 1

View File

@@ -276,6 +276,9 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str])
def get_full_path(folder_name: str, filename: str) -> str | None:
+    """
+    Get the full path of a file in a folder, has to be a file
+    """
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in folder_names_and_paths:
@@ -293,6 +296,9 @@ def get_full_path(folder_name: str, filename: str) -> str | None:
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
+    """
+    Get the full path of a file in a folder, has to be a file
+    """
full_path = get_full_path(folder_name, filename)
if full_path is None:
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
@@ -394,3 +400,26 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0):
os.makedirs(full_output_folder, exist_ok=True)
counter = 1
return full_output_folder, filename, counter, subfolder, filename_prefix
def get_input_subfolders() -> list[str]:
"""Returns a list of all subfolder paths in the input directory, recursively.
Returns:
List of folder paths relative to the input directory, excluding the root directory
"""
input_dir = get_input_directory()
folders = []
try:
if not os.path.exists(input_dir):
return []
for root, dirs, _ in os.walk(input_dir):
rel_path = os.path.relpath(root, input_dir)
if rel_path != ".": # Only include non-root directories
# Normalize path separators to forward slashes
folders.append(rel_path.replace(os.sep, '/'))
return sorted(folders)
except FileNotFoundError:
return []

View File

@@ -2231,6 +2231,7 @@ def init_builtin_extra_nodes():
        "nodes_model_downscale.py",
        "nodes_images.py",
        "nodes_video_model.py",
+        "nodes_train.py",
        "nodes_sag.py",
"nodes_perpneg.py",
"nodes_stable3d.py",

View File

@@ -0,0 +1,51 @@
import pytest
import os
import tempfile
from folder_paths import get_input_subfolders, set_input_directory
@pytest.fixture(scope="module")
def mock_folder_structure():
with tempfile.TemporaryDirectory() as temp_dir:
# Create a nested folder structure
folders = [
"folder1",
"folder1/subfolder1",
"folder1/subfolder2",
"folder2",
"folder2/deep",
"folder2/deep/nested",
"empty_folder"
]
# Create the folders
for folder in folders:
os.makedirs(os.path.join(temp_dir, folder))
# Add some files to test they're not included
with open(os.path.join(temp_dir, "root_file.txt"), "w") as f:
f.write("test")
with open(os.path.join(temp_dir, "folder1", "test.txt"), "w") as f:
f.write("test")
set_input_directory(temp_dir)
yield temp_dir
def test_gets_all_folders(mock_folder_structure):
folders = get_input_subfolders()
expected = ["folder1", "folder1/subfolder1", "folder1/subfolder2",
"folder2", "folder2/deep", "folder2/deep/nested", "empty_folder"]
assert sorted(folders) == sorted(expected)
def test_handles_nonexistent_input_directory():
with tempfile.TemporaryDirectory() as temp_dir:
nonexistent = os.path.join(temp_dir, "nonexistent")
set_input_directory(nonexistent)
assert get_input_subfolders() == []
def test_empty_input_directory():
with tempfile.TemporaryDirectory() as temp_dir:
set_input_directory(temp_dir)
assert get_input_subfolders() == [] # Empty since we don't include root