diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 470eb9fdb..071b98332 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -37,6 +37,8 @@ class IO(StrEnum): CONTROL_NET = "CONTROL_NET" VAE = "VAE" MODEL = "MODEL" + LORA_MODEL = "LORA_MODEL" + LOSS_MAP = "LOSS_MAP" CLIP_VISION = "CLIP_VISION" CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT" STYLE_MODEL = "STYLE_MODEL" diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 2cb77d85d..35d2270ee 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -753,7 +753,7 @@ class BasicTransformerBlock(nn.Module): for p in patch: n = p(n, extra_options) - x += n + x = n + x if "middle_patch" in transformer_patches: patch = transformer_patches["middle_patch"] for p in patch: @@ -793,12 +793,12 @@ class BasicTransformerBlock(nn.Module): for p in patch: n = p(n, extra_options) - x += n + x = n + x if self.is_res: x_skip = x x = self.ff(self.norm3(x)) if self.is_res: - x += x_skip + x = x_skip + x return x diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b7cb12dfc..b1d6d4395 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -17,23 +17,26 @@ """ from __future__ import annotations -from typing import Optional, Callable -import torch + +import collections import copy import inspect import logging -import uuid -import collections import math +import uuid +from typing import Callable, Optional + +import torch -import comfy.utils import comfy.float -import comfy.model_management -import comfy.lora import comfy.hooks +import comfy.lora +import comfy.model_management import comfy.patcher_extension -from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection +import comfy.utils from comfy.comfy_types import UnetWrapperFunction +from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP + def string_to_seed(data): crc = 0xFFFFFFFF diff --git a/comfy/sd.py b/comfy/sd.py index e98a3aa87..cd13ab5f0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1081,7 +1081,28 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return (model_patcher, clip, vae, clipvision) -def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format +def load_diffusion_model_state_dict(sd, model_options={}): + """ + Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats. + + Args: + sd (dict): State dictionary containing model weights and configuration + model_options (dict, optional): Additional options for model loading. Supports: + - dtype: Override model data type + - custom_operations: Custom model operations + - fp8_optimizations: Enable FP8 optimizations + + Returns: + ModelPatcher: A wrapped model instance that handles device management and weight loading. + Returns None if the model configuration cannot be detected. + + The function: + 1. Detects and handles different model formats (regular, diffusers, mmdit) + 2. Configures model dtype based on parameters and device capabilities + 3. Handles weight conversion and device placement + 4. Manages model optimization settings + 5. Loads weights and returns a device-managed model instance + """ dtype = model_options.get("dtype", None) #Allow loading unets from checkpoint files diff --git a/comfy/weight_adapter/__init__.py b/comfy/weight_adapter/__init__.py index d2a1d0151..560b82be3 100644 --- a/comfy/weight_adapter/__init__.py +++ b/comfy/weight_adapter/__init__.py @@ -1,4 +1,4 @@ -from .base import WeightAdapterBase +from .base import WeightAdapterBase, WeightAdapterTrainBase from .lora import LoRAAdapter from .loha import LoHaAdapter from .lokr import LoKrAdapter @@ -15,3 +15,9 @@ adapters: list[type[WeightAdapterBase]] = [ OFTAdapter, BOFTAdapter, ] + +__all__ = [ + "WeightAdapterBase", + "WeightAdapterTrainBase", + "adapters" +] + [a.__name__ for a in adapters] diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py index 29873519d..b5c7db423 100644 --- a/comfy/weight_adapter/base.py +++ b/comfy/weight_adapter/base.py @@ -12,12 +12,20 @@ class WeightAdapterBase: weights: list[torch.Tensor] @classmethod - def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]: + def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]: raise NotImplementedError def to_train(self) -> "WeightAdapterTrainBase": raise NotImplementedError + @classmethod + def create_train(cls, weight, *args) -> "WeightAdapterTrainBase": + """ + weight: The original weight tensor to be modified. + *args: Additional arguments for configuration, such as rank, alpha etc. + """ + raise NotImplementedError + def calculate_weight( self, weight, @@ -33,10 +41,22 @@ class WeightAdapterBase: class WeightAdapterTrainBase(nn.Module): + # We follow the scheme of PR #7032 def __init__(self): super().__init__() - # [TODO] Collaborate with LoRA training PR #7032 + def __call__(self, w): + """ + w: The original weight tensor to be modified. + """ + raise NotImplementedError + + def passive_memory_usage(self): + raise NotImplementedError("passive_memory_usage is not implemented") + + def move_to(self, device): + self.to(device) + return self.passive_memory_usage() def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function): @@ -102,3 +122,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten padded_tensor[new_slices] = tensor[orig_slices] return padded_tensor + + +def tucker_weight_from_conv(up, down, mid): + up = up.reshape(up.size(0), up.size(1)) + down = down.reshape(down.size(0), down.size(1)) + return torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down) + + +def tucker_weight(wa, wb, t): + temp = torch.einsum("i j ..., j r -> i r ...", t, wb) + return torch.einsum("i j ..., i r -> r j ...", temp, wa) diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index b2e623924..729dbd9e6 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -3,7 +3,56 @@ from typing import Optional import torch import comfy.model_management -from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape +from .base import ( + WeightAdapterBase, + WeightAdapterTrainBase, + weight_decompose, + pad_tensor_to_shape, + tucker_weight_from_conv, +) + + +class LoraDiff(WeightAdapterTrainBase): + def __init__(self, weights): + super().__init__() + mat1, mat2, alpha, mid, dora_scale, reshape = weights + out_dim, rank = mat1.shape[0], mat1.shape[1] + rank, in_dim = mat2.shape[0], mat2.shape[1] + if mid is not None: + convdim = mid.ndim - 2 + layer = ( + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d + )[convdim] + else: + layer = torch.nn.Linear + self.lora_up = layer(rank, out_dim, bias=False) + self.lora_down = layer(in_dim, rank, bias=False) + self.lora_up.weight.data.copy_(mat1) + self.lora_down.weight.data.copy_(mat2) + if mid is not None: + self.lora_mid = layer(mid, rank, bias=False) + self.lora_mid.weight.data.copy_(mid) + else: + self.lora_mid = None + self.rank = rank + self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False) + + def __call__(self, w): + org_dtype = w.dtype + if self.lora_mid is None: + diff = self.lora_up.weight @ self.lora_down.weight + else: + diff = tucker_weight_from_conv( + self.lora_up.weight, self.lora_down.weight, self.lora_mid.weight + ) + scale = self.alpha / self.rank + weight = w + scale * diff.reshape(w.shape) + return weight.to(org_dtype) + + def passive_memory_usage(self): + return sum(param.numel() * param.element_size() for param in self.parameters()) class LoRAAdapter(WeightAdapterBase): @@ -13,6 +62,21 @@ class LoRAAdapter(WeightAdapterBase): self.loaded_keys = loaded_keys self.weights = weights + @classmethod + def create_train(cls, weight, rank=1, alpha=1.0): + out_dim = weight.shape[0] + in_dim = weight.shape[1:].numel() + mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) + mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) + torch.nn.init.kaiming_uniform_(mat1, a=5**0.5) + torch.nn.init.constant_(mat2, 0.0) + return LoraDiff( + (mat1, mat2, alpha, None, None, None) + ) + + def to_train(self): + return LoraDiff(self.weights) + @classmethod def load( cls, diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py new file mode 100644 index 000000000..17a03a1e0 --- /dev/null +++ b/comfy_extras/nodes_train.py @@ -0,0 +1,705 @@ +import datetime +import json +import logging +import os + +import numpy as np +import safetensors +import torch +from PIL import Image, ImageDraw, ImageFont +from PIL.PngImagePlugin import PngInfo +import torch.utils.checkpoint +import tqdm + +import comfy.samplers +import comfy.sd +import comfy.utils +import comfy.model_management +import comfy_extras.nodes_custom_sampler +import folder_paths +import node_helpers +from comfy.cli_args import args +from comfy.comfy_types.node_typing import IO +from comfy.weight_adapter import adapters + + +class TrainSampler(comfy.samplers.Sampler): + + def __init__(self, loss_fn, optimizer, loss_callback=None): + self.loss_fn = loss_fn + self.optimizer = optimizer + self.loss_callback = loss_callback + + def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + self.optimizer.zero_grad() + noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False) + latent = model_wrap.inner_model.model_sampling.noise_scaling( + torch.zeros_like(sigmas), + torch.zeros_like(noise, requires_grad=True), + latent_image, + False + ) + + # Ensure model is in training mode and computing gradients + # x0 pred + denoised = model_wrap(noise, sigmas, **extra_args) + try: + loss = self.loss_fn(denoised, latent.clone()) + except RuntimeError as e: + if "does not require grad and does not have a grad_fn" in str(e): + logging.info("WARNING: This is likely due to the model is loaded in inference mode.") + loss.backward() + if self.loss_callback: + self.loss_callback(loss.item()) + + self.optimizer.step() + # torch.cuda.memory._dump_snapshot("trainn.pickle") + # torch.cuda.memory._record_memory_history(enabled=None) + return torch.zeros_like(latent_image) + + +class BiasDiff(torch.nn.Module): + def __init__(self, bias): + super().__init__() + self.bias = bias + + def __call__(self, b): + org_dtype = b.dtype + return (b.to(self.bias) + self.bias).to(org_dtype) + + def passive_memory_usage(self): + return self.bias.nelement() * self.bias.element_size() + + def move_to(self, device): + self.to(device=device) + return self.passive_memory_usage() + + +def load_and_process_images(image_files, input_dir, resize_method="None"): + """Utility function to load and process a list of images. + + Args: + image_files: List of image filenames + input_dir: Base directory containing the images + resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad") + + Returns: + torch.Tensor: Batch of processed images + """ + if not image_files: + raise ValueError("No valid images found in input") + + output_images = [] + w, h = None, None + + for file in image_files: + image_path = os.path.join(input_dir, file) + img = node_helpers.pillow(Image.open, image_path) + + if img.mode == "I": + img = img.point(lambda i: i * (1 / 255)) + img = img.convert("RGB") + + if w is None and h is None: + w, h = img.size[0], img.size[1] + + # Resize image to first image + if img.size[0] != w or img.size[1] != h: + if resize_method == "Stretch": + img = img.resize((w, h), Image.Resampling.LANCZOS) + elif resize_method == "Crop": + img = img.crop((0, 0, w, h)) + elif resize_method == "Pad": + img = img.resize((w, h), Image.Resampling.LANCZOS) + elif resize_method == "None": + raise ValueError( + "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images." + ) + + img_array = np.array(img).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array)[None,] + output_images.append(img_tensor) + + return torch.cat(output_images, dim=0) + + +class LoadImageSetNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "images": ( + [ + f + for f in os.listdir(folder_paths.get_input_directory()) + if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff")) + ], + {"image_upload": True, "allow_batch": True}, + ) + }, + "optional": { + "resize_method": ( + ["None", "Stretch", "Crop", "Pad"], + {"default": "None"}, + ), + }, + } + + INPUT_IS_LIST = True + RETURN_TYPES = ("IMAGE",) + FUNCTION = "load_images" + CATEGORY = "loaders" + EXPERIMENTAL = True + DESCRIPTION = "Loads a batch of images from a directory for training." + + @classmethod + def VALIDATE_INPUTS(s, images, resize_method): + filenames = images[0] if isinstance(images[0], list) else images + + for image in filenames: + if not folder_paths.exists_annotated_filepath(image): + return "Invalid image file: {}".format(image) + return True + + def load_images(self, input_files, resize_method): + input_dir = folder_paths.get_input_directory() + valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"] + image_files = [ + f + for f in input_files + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + output_tensor = load_and_process_images(image_files, input_dir, resize_method) + return (output_tensor,) + + +class LoadImageSetFromFolderNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}) + }, + "optional": { + "resize_method": ( + ["None", "Stretch", "Crop", "Pad"], + {"default": "None"}, + ), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "load_images" + CATEGORY = "loaders" + EXPERIMENTAL = True + DESCRIPTION = "Loads a batch of images from a directory for training." + + def load_images(self, folder, resize_method): + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] + image_files = [ + f + for f in os.listdir(sub_input_dir) + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method) + return (output_tensor,) + + +def draw_loss_graph(loss_map, steps): + width, height = 500, 300 + img = Image.new("RGB", (width, height), "white") + draw = ImageDraw.Draw(img) + + min_loss, max_loss = min(loss_map.values()), max(loss_map.values()) + scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_map.values()] + + prev_point = (0, height - int(scaled_loss[0] * height)) + for i, l in enumerate(scaled_loss[1:], start=1): + x = int(i / (steps - 1) * width) + y = height - int(l * height) + draw.line([prev_point, (x, y)], fill="blue", width=2) + prev_point = (x, y) + + return img + + +def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None): + if result is None: + result = [] + elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)): + result.append(model) + logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})") + return result + name = name or "root" + for next_name, child in model.named_children(): + find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}") + return result + + +def patch(m): + if not hasattr(m, "forward"): + return + org_forward = m.forward + def fwd(args, kwargs): + return org_forward(*args, **kwargs) + def checkpointing_fwd(*args, **kwargs): + return torch.utils.checkpoint.checkpoint( + fwd, args, kwargs, use_reentrant=False + ) + m.org_forward = org_forward + m.forward = checkpointing_fwd + + +def unpatch(m): + if hasattr(m, "org_forward"): + m.forward = m.org_forward + del m.org_forward + + +class TrainLoraNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}), + "latents": ( + "LATENT", + { + "tooltip": "The Latents to use for training, serve as dataset/input of the model." + }, + ), + "positive": ( + IO.CONDITIONING, + {"tooltip": "The positive conditioning to use for training."}, + ), + "batch_size": ( + IO.INT, + { + "default": 1, + "min": 1, + "max": 10000, + "step": 1, + "tooltip": "The batch size to use for training.", + }, + ), + "steps": ( + IO.INT, + { + "default": 16, + "min": 1, + "max": 100000, + "tooltip": "The number of steps to train the LoRA for.", + }, + ), + "learning_rate": ( + IO.FLOAT, + { + "default": 0.0005, + "min": 0.0000001, + "max": 1.0, + "step": 0.000001, + "tooltip": "The learning rate to use for training.", + }, + ), + "rank": ( + IO.INT, + { + "default": 8, + "min": 1, + "max": 128, + "tooltip": "The rank of the LoRA layers.", + }, + ), + "optimizer": ( + ["AdamW", "Adam", "SGD", "RMSprop"], + { + "default": "AdamW", + "tooltip": "The optimizer to use for training.", + }, + ), + "loss_function": ( + ["MSE", "L1", "Huber", "SmoothL1"], + { + "default": "MSE", + "tooltip": "The loss function to use for training.", + }, + ), + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 0xFFFFFFFFFFFFFFFF, + "tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)", + }, + ), + "training_dtype": ( + ["bf16", "fp32"], + {"default": "bf16", "tooltip": "The dtype to use for training."}, + ), + "lora_dtype": ( + ["bf16", "fp32"], + {"default": "bf16", "tooltip": "The dtype to use for lora."}, + ), + "existing_lora": ( + folder_paths.get_filename_list("loras") + ["[None]"], + { + "default": "[None]", + "tooltip": "The existing LoRA to append to. Set to None for new LoRA.", + }, + ), + }, + } + + RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT) + RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps") + FUNCTION = "train" + CATEGORY = "training" + EXPERIMENTAL = True + + def train( + self, + model, + latents, + positive, + batch_size, + steps, + learning_rate, + rank, + optimizer, + loss_function, + seed, + training_dtype, + lora_dtype, + existing_lora, + ): + mp = model.clone() + dtype = node_helpers.string_to_torch_dtype(training_dtype) + lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) + mp.set_model_compute_dtype(dtype) + + latents = latents["samples"].to(dtype) + num_images = latents.shape[0] + + with torch.inference_mode(False): + lora_sd = {} + generator = torch.Generator() + generator.manual_seed(seed) + + # Load existing LoRA weights if provided + existing_weights = {} + existing_steps = 0 + if existing_lora != "[None]": + lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora) + # Extract steps from filename like "trained_lora_10_steps_20250225_203716" + existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1]) + if lora_path: + existing_weights = comfy.utils.load_torch_file(lora_path) + + all_weight_adapters = [] + for n, m in mp.model.named_modules(): + if hasattr(m, "weight_function"): + if m.weight is not None: + key = "{}.weight".format(n) + shape = m.weight.shape + if len(shape) >= 2: + alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) + dora_scale = existing_weights.get( + f"{key}.dora_scale", None + ) + for adapter_cls in adapters: + existing_adapter = adapter_cls.load( + n, existing_weights, alpha, dora_scale + ) + if existing_adapter is not None: + break + else: + # If no existing adapter found, use LoRA + # We will add algo option in the future + existing_adapter = None + adapter_cls = adapters[0] + + if existing_adapter is not None: + train_adapter = existing_adapter.to_train().to(lora_dtype) + else: + # Use LoRA with alpha=1.0 by default + train_adapter = adapter_cls.create_train( + m.weight, rank=rank, alpha=1.0 + ).to(lora_dtype) + for name, parameter in train_adapter.named_parameters(): + lora_sd[f"{n}.{name}"] = parameter + + mp.add_weight_wrapper(key, train_adapter) + all_weight_adapters.append(train_adapter) + else: + diff = torch.nn.Parameter( + torch.zeros( + m.weight.shape, dtype=lora_dtype, requires_grad=True + ) + ) + diff_module = BiasDiff(diff) + mp.add_weight_wrapper(key, BiasDiff(diff)) + all_weight_adapters.append(diff_module) + lora_sd["{}.diff".format(n)] = diff + if hasattr(m, "bias") and m.bias is not None: + key = "{}.bias".format(n) + bias = torch.nn.Parameter( + torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True) + ) + bias_module = BiasDiff(bias) + lora_sd["{}.diff_b".format(n)] = bias + mp.add_weight_wrapper(key, BiasDiff(bias)) + all_weight_adapters.append(bias_module) + + if optimizer == "Adam": + optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate) + elif optimizer == "AdamW": + optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate) + elif optimizer == "SGD": + optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate) + elif optimizer == "RMSprop": + optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate) + + # Setup loss function based on selection + if loss_function == "MSE": + criterion = torch.nn.MSELoss() + elif loss_function == "L1": + criterion = torch.nn.L1Loss() + elif loss_function == "Huber": + criterion = torch.nn.HuberLoss() + elif loss_function == "SmoothL1": + criterion = torch.nn.SmoothL1Loss() + + # setup models + for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model): + patch(m) + comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True) + + # Setup sampler and guider like in test script + loss_map = {"loss": []} + def loss_callback(loss): + loss_map["loss"].append(loss) + pbar.set_postfix({"loss": f"{loss:.4f}"}) + train_sampler = TrainSampler( + criterion, optimizer, loss_callback=loss_callback + ) + guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp) + guider.set_conds(positive) # Set conditioning from input + ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced() + + # yoland: this currently resize to the first image in the dataset + + # Training loop + torch.cuda.empty_cache() + try: + for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)): + # Generate random sigma + sigma = mp.model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + sigma = torch.tensor([sigma]) + + noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed) + + indices = torch.randperm(num_images)[:batch_size] + ss.sample( + noise, guider, train_sampler, sigma, {"samples": latents[indices].clone()} + ) + finally: + for m in mp.model.modules(): + unpatch(m) + del ss, train_sampler, optimizer + torch.cuda.empty_cache() + + for adapter in all_weight_adapters: + adapter.requires_grad_(False) + + for param in lora_sd: + lora_sd[param] = lora_sd[param].to(lora_dtype) + + return (mp, lora_sd, loss_map, steps + existing_steps) + + +class LoraModelLoader: + def __init__(self): + self.loaded_lora = None + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), + "lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to apply to the diffusion model."}), + "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), + } + } + + RETURN_TYPES = ("MODEL",) + OUTPUT_TOOLTIPS = ("The modified diffusion model.",) + FUNCTION = "load_lora_model" + + CATEGORY = "loaders" + DESCRIPTION = "Load Trained LoRA weights from Train LoRA node." + EXPERIMENTAL = True + + def load_lora_model(self, model, lora, strength_model): + if strength_model == 0: + return (model, ) + + model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0) + return (model_lora, ) + + +class SaveLoRA: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "lora": ( + IO.LORA_MODEL, + { + "tooltip": "The LoRA model to save. Do not use the model with LoRA layers." + }, + ), + "prefix": ( + "STRING", + { + "default": "trained_lora", + "tooltip": "The prefix to use for the saved LoRA file.", + }, + ), + }, + "optional": { + "steps": ( + IO.INT, + { + "forceInput": True, + "tooltip": "Optional: The number of steps to LoRA has been trained for, used to name the saved file.", + }, + ), + }, + } + + RETURN_TYPES = () + FUNCTION = "save" + CATEGORY = "loaders" + EXPERIMENTAL = True + OUTPUT_NODE = True + + def save(self, lora, prefix, steps=None): + date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + if steps is None: + output_file = f"models/loras/{prefix}_{date}_lora.safetensors" + else: + output_file = f"models/loras/{prefix}_{steps}_steps_{date}_lora.safetensors" + safetensors.torch.save_file(lora, output_file) + return {} + + +class LossGraphNode: + def __init__(self): + self.output_dir = folder_paths.get_temp_directory() + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "loss": (IO.LOSS_MAP, {"default": {}}), + "filename_prefix": (IO.STRING, {"default": "loss_graph"}), + }, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + + RETURN_TYPES = () + FUNCTION = "plot_loss" + OUTPUT_NODE = True + CATEGORY = "training" + EXPERIMENTAL = True + DESCRIPTION = "Plots the loss graph and saves it to the output directory." + + def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None): + loss_values = loss["loss"] + width, height = 800, 480 + margin = 40 + + img = Image.new( + "RGB", (width + margin, height + margin), "white" + ) # Extend canvas + draw = ImageDraw.Draw(img) + + min_loss, max_loss = min(loss_values), max(loss_values) + scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_values] + + steps = len(loss_values) + + prev_point = (margin, height - int(scaled_loss[0] * height)) + for i, l in enumerate(scaled_loss[1:], start=1): + x = margin + int(i / steps * width) # Scale X properly + y = height - int(l * height) + draw.line([prev_point, (x, y)], fill="blue", width=2) + prev_point = (x, y) + + draw.line([(margin, 0), (margin, height)], fill="black", width=2) # Y-axis + draw.line( + [(margin, height), (width + margin, height)], fill="black", width=2 + ) # X-axis + + font = None + try: + font = ImageFont.truetype("arial.ttf", 12) + except IOError: + font = ImageFont.load_default() + + # Add axis labels + draw.text((5, height // 2), "Loss", font=font, fill="black") + draw.text((width // 2, height + 10), "Steps", font=font, fill="black") + + # Add min/max loss values + draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black") + draw.text( + (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black" + ) + + metadata = None + if not args.disable_metadata: + metadata = PngInfo() + if prompt is not None: + metadata.add_text("prompt", json.dumps(prompt)) + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata.add_text(x, json.dumps(extra_pnginfo[x])) + + date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + img.save( + os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"), + pnginfo=metadata, + ) + return { + "ui": { + "images": [ + { + "filename": f"{filename_prefix}_{date}.png", + "subfolder": "", + "type": "temp", + } + ] + } + } + + +NODE_CLASS_MAPPINGS = { + "TrainLoraNode": TrainLoraNode, + "SaveLoRANode": SaveLoRA, + "LoraModelLoader": LoraModelLoader, + "LoadImageSetFromFolderNode": LoadImageSetFromFolderNode, + "LossGraphNode": LossGraphNode, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "TrainLoraNode": "Train LoRA", + "SaveLoRANode": "Save LoRA Weights", + "LoraModelLoader": "Load LoRA Model", + "LoadImageSetFromFolderNode": "Load Image Dataset from Folder", + "LossGraphNode": "Plot Loss Graph", +} diff --git a/execution.py b/execution.py index 15ff7567c..d0012afda 100644 --- a/execution.py +++ b/execution.py @@ -1,23 +1,35 @@ -import sys import copy -import logging -import threading import heapq +import inspect +import logging +import sys +import threading import time import traceback from enum import Enum -import inspect from typing import List, Literal, NamedTuple, Optional import torch -import nodes import comfy.model_management -from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker -from comfy_execution.graph_utils import is_link, GraphBuilder -from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID +import nodes +from comfy_execution.caching import ( + CacheKeySetID, + CacheKeySetInputSignature, + DependencyAwareCache, + HierarchicalCache, + LRUCache, +) +from comfy_execution.graph import ( + DynamicPrompt, + ExecutionBlocker, + ExecutionList, + get_input_info, +) +from comfy_execution.graph_utils import GraphBuilder, is_link from comfy_execution.validation import validate_node_input + class ExecutionResult(Enum): SUCCESS = 0 FAILURE = 1 diff --git a/folder_paths.py b/folder_paths.py index f0b3fd103..9ec952940 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -276,6 +276,9 @@ def filter_files_extensions(files: Collection[str], extensions: Collection[str]) def get_full_path(folder_name: str, filename: str) -> str | None: + """ + Get the full path of a file in a folder, has to be a file + """ global folder_names_and_paths folder_name = map_legacy(folder_name) if folder_name not in folder_names_and_paths: @@ -293,6 +296,9 @@ def get_full_path(folder_name: str, filename: str) -> str | None: def get_full_path_or_raise(folder_name: str, filename: str) -> str: + """ + Get the full path of a file in a folder, has to be a file + """ full_path = get_full_path(folder_name, filename) if full_path is None: raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.") @@ -394,3 +400,26 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im os.makedirs(full_output_folder, exist_ok=True) counter = 1 return full_output_folder, filename, counter, subfolder, filename_prefix + +def get_input_subfolders() -> list[str]: + """Returns a list of all subfolder paths in the input directory, recursively. + + Returns: + List of folder paths relative to the input directory, excluding the root directory + """ + input_dir = get_input_directory() + folders = [] + + try: + if not os.path.exists(input_dir): + return [] + + for root, dirs, _ in os.walk(input_dir): + rel_path = os.path.relpath(root, input_dir) + if rel_path != ".": # Only include non-root directories + # Normalize path separators to forward slashes + folders.append(rel_path.replace(os.sep, '/')) + + return sorted(folders) + except FileNotFoundError: + return [] diff --git a/nodes.py b/nodes.py index 8e5b47b37..bfc342275 100644 --- a/nodes.py +++ b/nodes.py @@ -2231,6 +2231,7 @@ def init_builtin_extra_nodes(): "nodes_model_downscale.py", "nodes_images.py", "nodes_video_model.py", + "nodes_train.py", "nodes_sag.py", "nodes_perpneg.py", "nodes_stable3d.py", diff --git a/tests-unit/folder_paths_test/misc_test.py b/tests-unit/folder_paths_test/misc_test.py new file mode 100644 index 000000000..fcf667453 --- /dev/null +++ b/tests-unit/folder_paths_test/misc_test.py @@ -0,0 +1,51 @@ +import pytest +import os +import tempfile +from folder_paths import get_input_subfolders, set_input_directory + +@pytest.fixture(scope="module") +def mock_folder_structure(): + with tempfile.TemporaryDirectory() as temp_dir: + # Create a nested folder structure + folders = [ + "folder1", + "folder1/subfolder1", + "folder1/subfolder2", + "folder2", + "folder2/deep", + "folder2/deep/nested", + "empty_folder" + ] + + # Create the folders + for folder in folders: + os.makedirs(os.path.join(temp_dir, folder)) + + # Add some files to test they're not included + with open(os.path.join(temp_dir, "root_file.txt"), "w") as f: + f.write("test") + with open(os.path.join(temp_dir, "folder1", "test.txt"), "w") as f: + f.write("test") + + set_input_directory(temp_dir) + yield temp_dir + + +def test_gets_all_folders(mock_folder_structure): + folders = get_input_subfolders() + expected = ["folder1", "folder1/subfolder1", "folder1/subfolder2", + "folder2", "folder2/deep", "folder2/deep/nested", "empty_folder"] + assert sorted(folders) == sorted(expected) + + +def test_handles_nonexistent_input_directory(): + with tempfile.TemporaryDirectory() as temp_dir: + nonexistent = os.path.join(temp_dir, "nonexistent") + set_input_directory(nonexistent) + assert get_input_subfolders() == [] + + +def test_empty_input_directory(): + with tempfile.TemporaryDirectory() as temp_dir: + set_input_directory(temp_dir) + assert get_input_subfolders() == [] # Empty since we don't include root